diff --git a/AGENTS.md b/AGENTS.md index acf53a31..5f8206de 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -4,15 +4,34 @@ - We don't care about backwards compatibility because we are still in development. Keep the code simple and lean. - Avoid adding new Rust features or ENVs unless it is explicitly approved. - Never modify this file without explicit approval. +- No single file should ever exceed 1,500 lines of code unless explicitly confirmed by the user. ## Testing - Never add tests in the same implementation file, always prefer to add them to a file inside tests/ (current or new) - If you add a test to catch a problem, the test should fail if aims to confirm a problem. +- Always use `FoldingMode::Optimized` in tests. Never use `FoldingMode::PaperExact` unless the user explicitly approves it. PaperExact is an O(2^ell) brute-force reference engine meant only for correctness cross-checking, not general test usage. ## Build & Test Commands - When running tests use --release eg cargo test --workspace --release - For extra debugs use debug-logs eg --features paper-exact,debug-logs +## Perf & Constraint Debugging + +Perf tests live in `crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs`. All use `--ignored`. + +Full constraint architecture report (main CCS, bus, Route-A claims, openings, timing): +```bash +NS_DEBUG_N=10 cargo test -p neo-fold --release --test perf -- --ignored --nocapture report_track_a_w0_w1_snapshot +``` +N: number of riscv instructions + 1 (halt). 
+ +Other useful tests (all accept `NS_DEBUG_N`): +- `debug_trace_single_n_mixed_ops` — trace-wiring prove/verify + openings +- `debug_chunked_single_n_mixed_ops` — same in chunked trace mode +- `debug_trace_vs_chunked_single_n_mixed_ops` — side-by-side comparison +- `report_trace_vs_chunked_medians` — 5-run median timing +- `debug_trace_core_rows_per_cycle_equiv` — CCS rows/cycle (no prove, fast; uses `NS_DEBUG_T`) + ## Profiling | Tool | Use Case | Output | diff --git a/Cargo.lock b/Cargo.lock index c83f1e4b..6b5901ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1231,6 +1231,8 @@ dependencies = [ "p3-field 0.4.1", "p3-goldilocks 0.4.1", "p3-matrix 0.4.1", + "rand 0.9.2", + "rand_chacha 0.9.0", "serde", "thiserror 2.0.17", ] diff --git a/crates/neo-ccs/src/r1cs.rs b/crates/neo-ccs/src/r1cs.rs index f094dc4d..5e673a38 100644 --- a/crates/neo-ccs/src/r1cs.rs +++ b/crates/neo-ccs/src/r1cs.rs @@ -7,7 +7,7 @@ use crate::{ }; /// Minimal **R1CS → CCS** helper: given A, B, C ∈ F^{n×m}, produce CCS with -/// M_1=A, M_2=B, M_3=C and f(X1,X2,X3) = X1·X2 − X3 (elementwise). +/// M_0=A, M_1=B, M_2=C and f(X0,X1,X2) = X0·X1 − X2 (elementwise). /// /// This is the standard embedding: row-wise, `A z ∘ B z = C z`, i.e., `f=0`. pub fn r1cs_to_ccs(a: Mat, b: Mat, c: Mat) -> CcsStructure { @@ -16,10 +16,7 @@ pub fn r1cs_to_ccs(a: Mat, b: Mat, c: Mat) -> CcsStructure assert_eq!(a.cols(), b.cols()); assert_eq!(a.cols(), c.cols()); - let n = a.rows(); - let m = a.cols(); - - // Base polynomial f(X1,X2,X3) = X1 * X2 - X3 + // Base polynomial f(X0,X1,X2) = X0 * X1 - X2 let base_terms = vec![ Term { coeff: F::ONE, @@ -32,12 +29,5 @@ pub fn r1cs_to_ccs(a: Mat, b: Mat, c: Mat) -> CcsStructure ]; let f_base = SparsePoly::new(3, base_terms); - // Insert identity-first only when square; otherwise keep legacy 3-matrix CCS. 
- if n == m { - let i_n = Mat::::identity(n); - let f = f_base.insert_var_at_front(); - CcsStructure::new(vec![i_n, a, b, c], f).expect("valid identity-first CCS structure") - } else { - CcsStructure::new(vec![a, b, c], f_base).expect("valid R1CS→CCS structure") - } + CcsStructure::new(vec![a, b, c], f_base).expect("valid R1CS→CCS structure") } diff --git a/crates/neo-ccs/tests/identity_validation_tests.rs b/crates/neo-ccs/tests/identity_validation_tests.rs index 1ba45aaa..b8b6be2a 100644 --- a/crates/neo-ccs/tests/identity_validation_tests.rs +++ b/crates/neo-ccs/tests/identity_validation_tests.rs @@ -1,4 +1,4 @@ -//! Tests for M₀ = I_n validation required by Ajtai/NC pipeline. +//! Tests for legacy identity-first validation helpers used by Ajtai/NC-specific paths. #![allow(non_snake_case)] @@ -7,7 +7,7 @@ use neo_ccs::{r1cs_to_ccs, CcsStructure, Mat}; use neo_math::F; use p3_field::PrimeCharacteristicRing; -/// Test that a valid square CCS with M₀ = I passes validation +/// Test that square R1CS output can be normalized to identity-first when needed. #[test] fn test_identity_validation_valid_square_ccs() { // Create a simple square R1CS (4x4) @@ -23,11 +23,13 @@ fn test_identity_validation_valid_square_ccs() { B[(0, 1)] = F::ONE; C[(0, 2)] = F::ONE; - // r1cs_to_ccs should produce identity-first CCS for square input + // r1cs_to_ccs always produces 3-matrix embedding now. let ccs = r1cs_to_ccs(A, B, C); + assert_eq!(ccs.matrices.len(), 3); - // Should pass validation - assert!(ccs.assert_m0_is_identity_for_nc().is_ok()); + // Explicit identity-first normalization still supports legacy validation paths. 
+ let ccs_normalized = ccs.ensure_identity_first().expect("normalize"); + assert!(ccs_normalized.assert_m0_is_identity_for_nc().is_ok()); } /// Test that a non-square CCS fails validation with clear error @@ -157,20 +159,21 @@ fn test_happy_path_square_r1cs_to_validated_ccs() { B[(4, 2)] = F::ONE; C[(4, 2)] = F::ONE; - // Convert to CCS (should be identity-first for square) + // Convert to CCS (3-matrix embedding, no auto identity insertion). let ccs = r1cs_to_ccs(A, B, C); // Verify it's square assert_eq!(ccs.n, ccs.m); assert_eq!(ccs.n, n); - // Verify M₀ is identity - assert!(ccs.matrices[0].is_identity()); + // By default the first matrix is A, not identity. + assert!(!ccs.matrices[0].is_identity()); + assert_eq!(ccs.matrices.len(), 3); - // Validation should pass - assert!(ccs.assert_m0_is_identity_for_nc().is_ok()); + // Legacy validation path requires explicit normalization. + assert!(ccs.assert_m0_is_identity_for_nc().is_err()); - // ensure_identity_first should be a no-op + // ensure_identity_first produces identity-first form for square CCS. 
let ccs_normalized = ccs.ensure_identity_first().expect("normalize"); assert!(ccs_normalized.assert_m0_is_identity_for_nc().is_ok()); } diff --git a/crates/neo-ccs/tests/r1cs_embedding_shape_tests.rs b/crates/neo-ccs/tests/r1cs_embedding_shape_tests.rs new file mode 100644 index 00000000..1d5f8c87 --- /dev/null +++ b/crates/neo-ccs/tests/r1cs_embedding_shape_tests.rs @@ -0,0 +1,38 @@ +use neo_ccs::{r1cs_to_ccs, Mat}; +use neo_math::F; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn square_r1cs_uses_three_matrix_embedding() { + let n = 4usize; + let m = 4usize; + + let mut a = Mat::zero(n, m, F::ZERO); + let mut b = Mat::zero(n, m, F::ZERO); + let mut c = Mat::zero(n, m, F::ZERO); + a[(0, 0)] = F::ONE; + b[(0, 1)] = F::ONE; + c[(0, 2)] = F::ONE; + + let ccs = r1cs_to_ccs(a, b, c); + assert_eq!(ccs.t(), 3, "square R1CS must not auto-insert identity matrix"); + assert!(!ccs.matrices[0].is_identity(), "M0 should be A, not identity"); +} + +#[test] +fn rectangular_r1cs_uses_three_matrix_embedding() { + let n = 2usize; + let m = 5usize; + + let mut a = Mat::zero(n, m, F::ZERO); + let mut b = Mat::zero(n, m, F::ZERO); + let mut c = Mat::zero(n, m, F::ZERO); + a[(0, 0)] = F::ONE; + b[(0, 1)] = F::ONE; + c[(0, 2)] = F::ONE; + + let ccs = r1cs_to_ccs(a, b, c); + assert_eq!(ccs.t(), 3); + assert_eq!(ccs.n, n); + assert_eq!(ccs.m, m); +} diff --git a/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs index 82f7dda5..a25b68fa 100644 --- a/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs +++ b/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs @@ -1,237 +1,72 @@ -//! End-to-end prove+verify for a small *compiled* RV32 guest program under the B1 shared-bus step circuit. -//! -//! The guest is authored in Rust under `riscv-tests/guests/rv32-fibonacci/` and its ROM bytes are committed in -//! 
`riscv-tests/binaries/rv32_fibonacci_rom.rs` so this test doesn't need to cross-compile at runtime. +//! End-to-end prove+verify for a small Fibonacci-style RV32 program under trace wiring. #![allow(non_snake_case)] -#[path = "binaries/rv32_fibonacci_rom.rs"] -mod rv32_fibonacci_rom; - -use neo_fold::riscv_shard::Rv32B1; +use neo_fold::riscv_trace_shard::Rv32TraceWiring; use neo_math::F; -use p3_field::{PrimeCharacteristicRing, PrimeField64}; -use sha2::{Digest, Sha256}; - -fn hex32(bytes: [u8; 32]) -> String { - let mut out = String::with_capacity(64); - for b in bytes { - out.push_str(&format!("{:02x}", b)); - } - out -} - -fn sha256_field_slice(values: &[F]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for v in values { - hasher.update(v.as_canonical_u64().to_le_bytes()); - } - hasher.finalize().into() -} - -fn sha256_fields_concat(x: &[F], w: &[F]) -> [u8; 32] { - let mut hasher = Sha256::new(); - for v in x { - hasher.update(v.as_canonical_u64().to_le_bytes()); - } - for v in w { - hasher.update(v.as_canonical_u64().to_le_bytes()); - } - hasher.finalize().into() -} - -fn nonzero_concat(x: &[F], w: &[F]) -> usize { - x.iter().chain(w.iter()).filter(|v| **v != F::ZERO).count() -} - -fn preview_first_last(values: &[F], n: usize) -> (Vec, Vec) { - let n = n.min(values.len()); - if n == 0 { - return (Vec::new(), Vec::new()); - } - let first = values - .iter() - .take(n) - .map(|v| v.as_canonical_u64()) - .collect::>(); - let last = values - .iter() - .skip(values.len().saturating_sub(n)) - .map(|v| v.as_canonical_u64()) - .collect::>(); - (first, last) -} - -fn preview_first_last_concat(x: &[F], w: &[F], n: usize) -> (Vec, Vec) { - let z_len = x.len() + w.len(); - let n = n.min(z_len); - if n == 0 { - return (Vec::new(), Vec::new()); - } - - let mut first = Vec::with_capacity(n); - for v in x.iter().chain(w.iter()).take(n) { - first.push(v.as_canonical_u64()); - } - - let last = if w.len() >= n { - w.iter() - .skip(w.len() - n) - .map(|v| v.as_canonical_u64()) 
- .collect::>() - } else { - let need_from_x = n - w.len(); - x.iter() - .skip(x.len().saturating_sub(need_from_x)) - .chain(w.iter()) - .map(|v| v.as_canonical_u64()) - .collect::>() - }; - - (first, last) -} +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; #[test] fn test_riscv_fibonacci_compiled_full_prove_verify() { - // The guest reads n from RAM[0x104], computes fib(n), and writes the result to RAM[0x100]. - let n = 10u32; + // Straight-line "fib-style" program: + // - x1 = 34 + // - x2 = 21 + // - x3 = x1 + x2 = 55 + // - mem[0x100] = x3 + // - HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 34, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 21, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Add, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 3, + imm: 0x100, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); let expected = F::from_u64(55); - let program_base = rv32_fibonacci_rom::RV32_FIBONACCI_ROM_BASE; - let program_bytes: &[u8] = &rv32_fibonacci_rom::RV32_FIBONACCI_ROM; - - println!( - "RV32 ELF .neo_start size: {} bytes ({} instructions)", - program_bytes.len(), - program_bytes.len() / 4 - ); - - let mut run = Rv32B1::from_rom(program_base, program_bytes) + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) .xlen(32) - .ram_bytes(0x800) - .ram_init_u32(/*addr=*/ 0x104, n) - .chunk_size(16) + .chunk_rows(program.len()) + .min_trace_len(program.len()) + .max_steps(program.len()) .shout_auto_minimal() .output(/*output_addr=*/ 0x100, /*expected_output=*/ expected) .prove() .expect("prove"); - println!( - "RV32 executed steps (trace len): {}", - run.riscv_trace_len().expect("trace len") - ); - println!( - "Circuit size (CCS): n_constraints={} m_variables={}", - 
run.ccs_num_constraints(), - run.ccs_num_variables() - ); - println!( - "Shout lookups used: {}", - run.shout_lookup_count().expect("shout lookup count") - ); - println!("Folds: {}", run.fold_count()); + run.verify().expect("verify"); - // Print proof size estimate + match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(32) + .chunk_rows(program.len()) + .min_trace_len(program.len()) + .max_steps(program.len()) + .shout_auto_minimal() + .output(/*output_addr=*/ 0x100, /*expected_output=*/ F::from_u64(56)) + .prove() { - let proof = run.proof(); - let num_steps = proof.steps.len(); - // Each MeInstance has exactly one commitment - let num_commitments: usize = proof - .steps - .iter() - .map(|s| { - s.fold.ccs_out.len() + s.fold.dec_children.len() + 1 // +1 for rlc_parent - + s.mem.cpu_me_claims_val.len() - + s.val_fold.as_ref().map(|v| v.dec_children.len() + 1).unwrap_or(0) - }) - .sum(); - // Commitment size: d * kappa * 8 bytes (d=54, kappa varies) - // Get d and kappa from the first commitment in the proof - let (d, kappa) = proof - .steps - .first() - .map(|s| (s.fold.rlc_parent.c.d, s.fold.rlc_parent.c.kappa)) - .unwrap_or((54, 2)); - let commitment_bytes = d * kappa * 8; - let estimated_bytes = num_commitments * commitment_bytes; - println!( - "Proof structure: {} steps, {} commitments (d={}, kappa={})", - num_steps, num_commitments, d, kappa - ); - println!( - "Estimated proof size (commitments only): {} bytes ({:.2} KB)", - estimated_bytes, - estimated_bytes as f64 / 1024.0 - ); - } - - let preview_len: usize = std::env::var("NIGHTSTREAM_WITNESS_PREVIEW_LEN") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(8); - let print_full = std::env::var("NIGHTSTREAM_PRINT_WITNESS_FULL").is_ok(); - - let steps_witness = run.steps_witness(); - println!( - "Step witness bundles: {} (expected folds={})", - steps_witness.len(), - run.fold_count() - ); - - for (fold_idx, step) in steps_witness.iter().enumerate() { - let mcs_inst = &step.mcs.0; - 
let mcs_wit = &step.mcs.1; - let x = &mcs_inst.x; - let w = &mcs_wit.w; - let z_len = x.len() + w.len(); - - let z_debug_sha256 = hex32(sha256_fields_concat(x, w)); - let w_debug_sha256 = hex32(sha256_field_slice(w)); - let Z_debug_sha256 = hex32(sha256_field_slice(mcs_wit.Z.as_slice())); - let z_nonzero = nonzero_concat(x, w); - - let (z_first, z_last) = preview_first_last_concat(x, w, preview_len); - let (Z_first, Z_last) = preview_first_last(mcs_wit.Z.as_slice(), preview_len); - - println!( - "Fold {fold_idx}: m_in={} x_len={} w_len={} z_len={} z_nonzero={} lut_instances={} mem_instances={} z_debug_sha256={} w_debug_sha256={} Z_debug_sha256={}", - mcs_inst.m_in, - x.len(), - w.len(), - z_len, - z_nonzero, - step.lut_instances.len(), - step.mem_instances.len(), - z_debug_sha256, - w_debug_sha256, - Z_debug_sha256, - ); - println!(" z_first={z_first:?}"); - println!(" z_last ={z_last:?}"); - println!(" Z_first={Z_first:?} (Z is {}x{})", mcs_wit.Z.rows(), mcs_wit.Z.cols()); - println!(" Z_last ={Z_last:?}"); - - if print_full { - println!( - " x_full={:?}", - x.iter().map(|v| v.as_canonical_u64()).collect::>() - ); - println!( - " w_full={:?}", - w.iter().map(|v| v.as_canonical_u64()).collect::>() - ); - } + Ok(mut bad_run) => assert!(bad_run.verify().is_err(), "wrong output claim must fail verification"), + Err(_) => {} } - - println!("Prove duration: {:?}", run.prove_duration()); - run.verify().expect("verify"); - println!("Verify duration: {:?}", run.verify_duration().expect("verify duration")); - - assert!( - matches!( - run.verify_output_claim(/*output_addr=*/ 0x100, /*expected_output=*/ F::from_u64(56)), - Ok(false) | Err(_) - ), - "wrong output claim must not verify" - ); } diff --git a/crates/neo-fold/riscv-tests/test_riscv_program_compiled_full_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_program_compiled_full_prove_verify.rs index 38ada8c0..581bbfbc 100644 --- a/crates/neo-fold/riscv-tests/test_riscv_program_compiled_full_prove_verify.rs +++ 
b/crates/neo-fold/riscv-tests/test_riscv_program_compiled_full_prove_verify.rs @@ -1,4 +1,4 @@ -//! End-to-end prove+verify for a small *compiled* RV32 guest program under the B1 shared-bus step circuit. +//! End-to-end prove+verify for a small *compiled* RV32 guest program under the trace wiring circuit. //! //! The guest is authored in Rust under `riscv-tests/guests/rv32-smoke/` and its ROM bytes are committed in //! `riscv-tests/binaries/rv32_smoke_rom.rs` so this test doesn't need to cross-compile at runtime. @@ -8,7 +8,7 @@ #[path = "binaries/rv32_smoke_rom.rs"] mod rv32_smoke_rom; -use neo_fold::riscv_shard::Rv32B1; +use neo_fold::riscv_trace_shard::Rv32TraceWiring; use neo_math::F; use p3_field::PrimeCharacteristicRing; @@ -16,10 +16,9 @@ use p3_field::PrimeCharacteristicRing; fn test_riscv_program_compiled_full_prove_verify() { let program_base = rv32_smoke_rom::RV32_SMOKE_ROM_BASE; let program_bytes: &[u8] = &rv32_smoke_rom::RV32_SMOKE_ROM; - let mut run = Rv32B1::from_rom(program_base, program_bytes) + let mut run = Rv32TraceWiring::from_rom(program_base, program_bytes) .xlen(32) - .ram_bytes(0x200) - .chunk_size(4) + .chunk_rows(4) .shout_auto_minimal() .output( /*output_addr=*/ 0x100, @@ -32,17 +31,17 @@ fn test_riscv_program_compiled_full_prove_verify() { run.verify().expect("verify"); println!("Verify duration: {:?}", run.verify_duration().expect("verify duration")); - let end = run.final_boundary_state().expect("boundary"); - assert_eq!(end.regs_final[4], F::from_u64(0x100c)); - - assert!( - matches!( - run.verify_output_claim( - /*output_addr=*/ 0x100, - /*expected_output=*/ F::from_u64(0x100d) - ), - Ok(false) | Err(_) - ), - "wrong output claim must not verify" - ); + match Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .chunk_rows(4) + .shout_auto_minimal() + .output( + /*output_addr=*/ 0x100, + /*expected_output=*/ F::from_u64(0x100d), + ) + .prove() + { + Ok(mut bad_run) => assert!(bad_run.verify().is_err(), "wrong output 
claim must fail verification"), + Err(_) => {} + } } diff --git a/crates/neo-fold/riscv-tests/test_riscv_program_crosscheck.rs b/crates/neo-fold/riscv-tests/test_riscv_program_crosscheck.rs index 4d9b9333..3bbeb2ea 100644 --- a/crates/neo-fold/riscv-tests/test_riscv_program_crosscheck.rs +++ b/crates/neo-fold/riscv-tests/test_riscv_program_crosscheck.rs @@ -7,7 +7,7 @@ //! paper-exact crosschecks from dominating CI time. use neo_fold::pi_ccs::FoldingMode; -use neo_fold::riscv_shard::Rv32B1; +use neo_fold::riscv_trace_shard::Rv32TraceWiring; use neo_math::F; use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; use neo_reductions::engines::CrosscheckCfg; @@ -42,19 +42,17 @@ fn test_riscv_program_crosscheck_tiny_trace() { outputs: true, }; - let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) .xlen(32) - .ram_bytes(0x40) - .chunk_size(1) + .chunk_rows(1) + .min_trace_len(3) .max_steps(3) .mode(FoldingMode::OptimizedWithCrosscheck(crosscheck_cfg)) + .reg_output_claim(/*reg=*/ 1, /*expected=*/ F::from_u64(12)) .prove() .expect("prove"); run.verify().expect("verify"); - - let end = run.final_boundary_state().expect("boundary"); - assert_eq!(end.regs_final[1], F::from_u64(12)); } #[test] @@ -87,19 +85,17 @@ fn test_riscv_program_crosscheck_full_flags_one_step() { outputs: true, }; - let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) .xlen(32) - .ram_bytes(0x40) - .chunk_size(1) + .chunk_rows(1) + .min_trace_len(1) .max_steps(1) .mode(FoldingMode::OptimizedWithCrosscheck(crosscheck_cfg)) + .reg_output_claim(/*reg=*/ 1, /*expected=*/ F::from_u64(7)) .prove() .expect("prove"); run.verify().expect("verify"); - - let end = run.final_boundary_state().expect("boundary"); - assert_eq!(end.regs_final[1], F::from_u64(7)); } #[test] @@ -132,17 +128,15 @@ fn 
test_riscv_program_crosscheck_full_flags_two_steps() { outputs: true, }; - let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) .xlen(32) - .ram_bytes(0x40) - .chunk_size(1) + .chunk_rows(1) + .min_trace_len(2) .max_steps(2) .mode(FoldingMode::OptimizedWithCrosscheck(crosscheck_cfg)) + .reg_output_claim(/*reg=*/ 1, /*expected=*/ F::from_u64(12)) .prove() .expect("prove"); run.verify().expect("verify"); - - let end = run.final_boundary_state().expect("boundary"); - assert_eq!(end.regs_final[1], F::from_u64(12)); } diff --git a/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs index c8c80225..787b7735 100644 --- a/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs +++ b/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs @@ -1,85 +1,12 @@ -//! End-to-end prove+verify for a small RV32 program under the B1 shared-bus step circuit. -//! -//! This exercises: -//! - B1 instruction fetch via `PROG_ID` Twist reads -//! - shared CPU bus tail wiring (Twist + Shout) -//! - implicit Shout table spec (`LutTableSpec::RiscvOpcode`) -//! - the RV32 B1 step CCS glue constraints +//! End-to-end prove+verify for small RV32 programs under the trace wiring circuit. 
#![allow(non_snake_case)] -use std::collections::HashMap; - -use neo_ajtai::Commitment as Cmt; -use neo_ccs::matrix::Mat; -use neo_ccs::traits::SModuleHomomorphism; -use neo_fold::pi_ccs::FoldingMode; -use neo_fold::riscv_shard::fold_shard_verify_rv32_b1_with_statement_mem_init; -use neo_fold::shard::{fold_shard_prove, CommitMixers}; -use neo_math::{F, K}; -use neo_memory::builder::build_shard_witness_shared_cpu_bus; -use neo_memory::plain::PlainMemLayout; -use neo_memory::riscv::ccs::{build_rv32_b1_step_ccs, rv32_b1_chunk_to_witness, rv32_b1_shared_cpu_bus_config}; -use neo_memory::riscv::lookups::{ - encode_program, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, -}; -use neo_memory::riscv::rom_init::prog_init_words; -use neo_memory::riscv::shard::extract_boundary_state; -use neo_memory::witness::{LutTableSpec, StepInstanceBundle}; -use neo_memory::R1csCpu; -use neo_params::NeoParams; -use neo_transcript::{Poseidon2Transcript, Transcript}; +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; use p3_field::PrimeCharacteristicRing; -#[derive(Clone, Copy, Default)] -struct DummyCommit; - -impl SModuleHomomorphism for DummyCommit { - fn commit(&self, z: &Mat) -> Cmt { - Cmt::zeros(z.rows(), 1) - } - - fn project_x(&self, z: &Mat, m_in: usize) -> Mat { - let rows = z.rows(); - let mut out = Mat::zero(rows, m_in, F::ZERO); - for r in 0..rows { - for c in 0..m_in.min(z.cols()) { - out[(r, c)] = z[(r, c)]; - } - } - out - } -} - -fn default_mixers() -> CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt> { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(neo_math::D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(neo_math::D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } -} - -fn pow2_ceil_k(min_k: usize) -> (usize, usize) { - let k = 
min_k.next_power_of_two().max(2); - let d = k.trailing_zeros() as usize; - (k, d) -} - -fn add_only_table_specs(xlen: usize) -> HashMap { - HashMap::from([( - 3u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Add, - xlen, - }, - )]) -} - #[test] fn test_riscv_program_full_prove_verify() { let xlen = 32usize; @@ -119,108 +46,21 @@ fn test_riscv_program_full_prove_verify() { RiscvInstruction::Halt, ]; let max_steps = program.len(); - let program_bytes = encode_program(&program); - let mut vm = RiscvCpu::new(xlen); - vm.load_program(0, program); - let twist = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - - // Keep k small to reduce bus tail width and proof work. - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - // Build CCS + shared-bus CPU arithmetization. - // Keep the Shout bus lean: this program only needs ADD (for ADD/ADDI and effective address calculation). 
- let shout_table_ids: Vec = vec![3u32]; - let add_idx = shout_table_ids - .iter() - .position(|&id| id == 3u32) - .expect("ADD table id present"); - let add_lane = u32::try_from(add_idx).expect("ADD lane index fits u32"); - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs_base.n).expect("params"); - - let table_specs = add_only_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs_base, - params.clone(), - DummyCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .expect("cfg"), - 1, - ) - .expect("shared bus inject"); - // Build shared-bus step bundles (includes CPU MCS + metadata-only mem/lut instances). - let lut_tables = HashMap::new(); - let steps = build_shard_witness_shared_cpu_bus::<_, Cmt, K, _, _, _>( - vm, - twist, - shout, - /*max_steps=*/ max_steps, - /*chunk_size=*/ 1, - &mem_layouts, - &lut_tables, - &table_specs, - &HashMap::new(), - &initial_mem, - &cpu, - ) - .expect("build shard witness"); - - let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(xlen) + .chunk_rows(1) + .max_steps(max_steps) + .shout_ops([RiscvOpcode::Add]) + .prove() + .expect("prove"); - let mixers = default_mixers(); - let mut tr_prove = Poseidon2Transcript::new(b"riscv-b1-full"); - // PaperExact is intentionally slow (brute-force oracle) and can make this end-to-end - // test take minutes. Use the optimized engine here and keep PaperExact covered by - // smaller unit tests. 
- let proof = fold_shard_prove( - FoldingMode::Optimized, - &mut tr_prove, - ¶ms, - &cpu.ccs, - &steps, - &[], - &[], - &DummyCommit::default(), - mixers, - ) - .expect("prove"); + run.verify().expect("verify"); // Ensure the Shout addr-pre proof skips inactive tables. - // This program uses only the ADD lookup; LUI and HALT use no Shout lookups and should skip entirely. - let mut saw_skipped = false; - let mut saw_add_only = false; + let proof = run.proof(); + let mut saw_active = false; for step in &proof.steps { let pre = &step.mem.shout_addr_pre; let active_lanes: Vec = pre @@ -230,57 +70,13 @@ fn test_riscv_program_full_prove_verify() { .collect(); if active_lanes.is_empty() { assert!(pre.groups.iter().all(|g| g.round_polys.is_empty())); - saw_skipped = true; continue; } - assert_eq!( - active_lanes, - vec![add_lane], - "expected ADD-only Shout addr-pre active_lanes" - ); let rounds_total: usize = pre.groups.iter().map(|g| g.round_polys.len()).sum(); - assert_eq!(rounds_total, 1, "ADD-only step must include 1 proof"); - saw_add_only = true; + assert!(rounds_total > 0, "active lanes must carry addr-pre round polys"); + saw_active = true; } - assert!(saw_skipped, "expected at least one no-Shout step (mask=0)"); - assert!(saw_add_only, "expected at least one ADD-lookup step (mask=ADD)"); - - let mut tr_verify = Poseidon2Transcript::new(b"riscv-b1-full"); - let _ = fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_verify, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &steps_public, - &[], - &proof, - mixers, - &layout, - ) - .expect("verify"); - - let mut bad_steps = steps_public.clone(); - bad_steps[1].mcs_inst.x[layout.pc0] += F::ONE; - let mut tr_bad = Poseidon2Transcript::new(b"riscv-b1-full"); - assert!( - fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_bad, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &bad_steps, - &[], - &proof, - mixers, - &layout, - ) - .is_err(), - 
"expected step linking failure" - ); + assert!(saw_active, "expected at least one active addr-pre step"); // Tamper: change Shout addr-pre active_lanes; verification must fail. let mut bad_proof = proof.clone(); @@ -304,163 +100,46 @@ fn test_riscv_program_full_prove_verify() { .expect("expected at least one active Shout addr-pre group"); group.active_lanes.clear(); group.round_polys.clear(); - let mut tr_bad_mask = Poseidon2Transcript::new(b"riscv-b1-full"); assert!( - fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_bad_mask, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &steps_public, - &[], - &bad_proof, - mixers, - &layout, - ) - .is_err(), + run.verify_proof(&bad_proof).is_err(), "expected Shout addr-pre active_lanes mismatch failure" ); } #[test] -fn test_riscv_statement_mem_init_mismatch_fails() { - let xlen = 32usize; - let program = vec![RiscvInstruction::Halt]; - let max_steps = program.len(); - +fn test_riscv_wrong_output_claim_fails_verify() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0x100, + }, + RiscvInstruction::Halt, + ]; let program_bytes = encode_program(&program); - // Keep k small to reduce bus tail width and proof work. - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - // Keep the Shout bus lean: this program uses no Shout lookups, but include ADD to keep the bus schema stable. 
- let table_specs = add_only_table_specs(xlen); - let shout_table_ids: Vec = vec![3u32]; - - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs_base.n).expect("params"); - - let cpu = R1csCpu::new( - ccs_base, - params.clone(), - DummyCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .expect("cfg"), - 1, - ) - .expect("shared bus inject"); - - let mut vm = RiscvCpu::new(xlen); - vm.load_program(0, program); - let twist = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - - let steps = build_shard_witness_shared_cpu_bus::<_, Cmt, K, _, _, _>( - vm, - twist, - shout, - max_steps, - 1, - &mem_layouts, - &HashMap::new(), - &table_specs, - &HashMap::new(), - &initial_mem, - &cpu, - ) - .expect("build shard witness"); - let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); - - let mixers = default_mixers(); - let mut tr_prove = Poseidon2Transcript::new(b"riscv-b1-stmt-mem-init"); - let proof = fold_shard_prove( - FoldingMode::Optimized, - &mut tr_prove, - ¶ms, - &cpu.ccs, - &steps, - &[], - &[], - &DummyCommit::default(), - mixers, - ) - .expect("prove"); - - // Sanity: correct statement must verify. - let mut tr_verify = Poseidon2Transcript::new(b"riscv-b1-stmt-mem-init"); - let _ = fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_verify, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &steps_public, - &[], - &proof, - mixers, - &layout, - ) - .expect("verify"); - - // Mismatch the *statement* initial memory (RAM starts non-zero) while keeping the proof fixed. - // Verification must fail at the chunk0 init check. 
- let mut bad_statement_initial_mem = initial_mem.clone(); - bad_statement_initial_mem.insert((0u32, 0u64), F::ONE); - - let mut tr_bad = Poseidon2Transcript::new(b"riscv-b1-stmt-mem-init"); - assert!( - fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_bad, - ¶ms, - &cpu.ccs, - &mem_layouts, - &bad_statement_initial_mem, - &steps_public, - &[], - &proof, - mixers, - &layout, - ) - .is_err(), - "expected statement init mismatch failure" - ); + match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(32) + .chunk_rows(1) + .max_steps(program.len()) + .output_claim(0x100, F::from_u64(8)) + .prove() + { + Ok(mut run) => assert!(run.verify().is_err(), "wrong output claim must fail verification"), + Err(_) => {} + } } #[test] #[ignore = "manual benchmark sweep; run with --ignored --nocapture"] -fn perf_rv32_b1_chunk_size_sweep() { +fn perf_rv32_trace_chunk_rows_sweep() { use std::time::Instant; let xlen = 32usize; @@ -470,203 +149,78 @@ fn perf_rv32_b1_chunk_size_sweep() { rd: 1, rs1: 0, imm: 1, - }, // x1 = 1 + }, RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 2, rs1: 0, imm: 2, - }, // x2 = 2 + }, RiscvInstruction::Branch { - cond: neo_memory::riscv::lookups::BranchCondition::Eq, + cond: BranchCondition::Eq, rs1: 1, rs2: 2, imm: 8, - }, // not taken + }, RiscvInstruction::Store { op: RiscvMemOp::Sw, rs1: 0, rs2: 1, imm: 0, - }, // mem[0] = x1 - RiscvInstruction::Jal { rd: 5, imm: 8 }, // jump over the next instruction + }, + RiscvInstruction::Jal { rd: 5, imm: 8 }, RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 3, rs1: 0, imm: 123, - }, // skipped + }, RiscvInstruction::Load { op: RiscvMemOp::Lw, rd: 3, rs1: 0, imm: 0, - }, // x3 = mem[0] + }, RiscvInstruction::Halt, ]; let program_bytes = encode_program(&program); let max_steps = 64usize; - // Keep k small to reduce bus tail width and proof work. 
- let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - fn table_specs_from_ids(ids: &[u32], xlen: usize) -> HashMap { - ids.iter() - .copied() - .map(|id| { - let opcode = match id { - 0 => RiscvOpcode::And, - 1 => RiscvOpcode::Xor, - 2 => RiscvOpcode::Or, - 3 => RiscvOpcode::Add, - 4 => RiscvOpcode::Sub, - 5 => RiscvOpcode::Slt, - 6 => RiscvOpcode::Sltu, - 7 => RiscvOpcode::Sll, - 8 => RiscvOpcode::Srl, - 9 => RiscvOpcode::Sra, - 10 => RiscvOpcode::Eq, - 11 => RiscvOpcode::Neq, - 12 => RiscvOpcode::Mul, - 13 => RiscvOpcode::Mulh, - 14 => RiscvOpcode::Mulhu, - 15 => RiscvOpcode::Mulhsu, - 16 => RiscvOpcode::Div, - 17 => RiscvOpcode::Divu, - 18 => RiscvOpcode::Rem, - 19 => RiscvOpcode::Remu, - _ => panic!("unsupported RV32 B1 table_id={id}"), - }; - (id, LutTableSpec::RiscvOpcode { opcode, xlen }) - }) - .collect() - } - - let profiles: &[(&str, &[u32])] = &[ - ("min3", neo_memory::riscv::ccs::RV32_B1_SHOUT_PROFILE_MIN3), - ("full12", neo_memory::riscv::ccs::RV32_B1_SHOUT_PROFILE_FULL12), + let profiles: &[(&str, &[RiscvOpcode])] = &[ + ("minimal", &[RiscvOpcode::Add]), + ("extended", &[RiscvOpcode::Add, RiscvOpcode::Sub, RiscvOpcode::Sltu]), ]; - let mixers = default_mixers(); - - for (profile_name, shout_table_ids) in profiles { - let table_specs = table_specs_from_ids(shout_table_ids, xlen); - println!("\n== profile={profile_name} shout_tables={} ==", shout_table_ids.len()); - - for chunk_size in [1usize, 2, 4, 8, 16] { - let mut vm = RiscvCpu::new(xlen); - vm.load_program(0, program.clone()); - let twist = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - - let 
(ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, shout_table_ids, chunk_size).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs_base.n).expect("params"); - - let cpu = R1csCpu::new( - ccs_base, - params.clone(), - DummyCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .expect("cfg"), - chunk_size, - ) - .expect("shared bus inject"); - - let t_build = Instant::now(); - let steps = build_shard_witness_shared_cpu_bus::<_, Cmt, K, _, _, _>( - vm, - twist, - shout, - max_steps, - chunk_size, - &mem_layouts, - &HashMap::new(), - &table_specs, - &HashMap::new(), - &initial_mem, - &cpu, - ) - .expect("build shard witness"); - let build_dur = t_build.elapsed(); - - let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); - - let mut tr_prove = Poseidon2Transcript::new(b"riscv-b1-chunk-sweep"); - let t_prove = Instant::now(); - let proof = fold_shard_prove( - FoldingMode::Optimized, - &mut tr_prove, - ¶ms, - &cpu.ccs, - &steps, - &[], - &[], - &DummyCommit::default(), - mixers, - ) - .expect("prove"); - let prove_dur = t_prove.elapsed(); - - let mut tr_verify = Poseidon2Transcript::new(b"riscv-b1-chunk-sweep"); - let t_verify = Instant::now(); - let _ = fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_verify, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &steps_public, - &[], - &proof, - mixers, - &layout, - ) - .expect("verify"); - let verify_dur = t_verify.elapsed(); + for (profile_name, ops) in profiles { + println!("\n== profile={profile_name} shout_tables={} ==", ops.len()); + + for chunk_rows in [1usize, 2, 4, 8, 16] { + let t_total = Instant::now(); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(xlen) + .chunk_rows(chunk_rows) + 
.max_steps(max_steps) + .shout_ops(ops.iter().copied()) + .prove() + .expect("prove"); + let total_dur = t_total.elapsed(); + let prove_dur = run.prove_duration(); + + run.verify().expect("verify"); + let verify_dur = run.verify_duration().expect("verify duration"); + let folds = run.fold_count(); println!( - "chunk_size={chunk_size:<2} chunks={:<3} build={:?} prove={:?} verify={:?}", - steps_public.len(), - build_dur, - prove_dur, - verify_dur + "chunk_rows={chunk_rows:<2} folds={folds:<3} prove={:?} verify={:?} total={:?}", + prove_dur, verify_dur, total_dur ); } } } #[test] -fn test_riscv_program_chunk_size_equivalence() { +fn test_riscv_program_chunk_rows_equivalence() { let xlen = 32usize; let program = vec![ RiscvInstruction::IAlu { @@ -674,214 +228,50 @@ fn test_riscv_program_chunk_size_equivalence() { rd: 1, rs1: 0, imm: 1, - }, // x1 = 1 + }, RiscvInstruction::Store { op: RiscvMemOp::Sw, rs1: 0, rs2: 1, imm: 0, - }, // mem[0] = x1 + }, RiscvInstruction::Load { op: RiscvMemOp::Lw, rd: 2, rs1: 0, imm: 0, - }, // x2 = mem[0] + }, RiscvInstruction::Halt, ]; - - let program_bytes = encode_program(&program); let max_steps = program.len(); + let program_bytes = encode_program(&program); - // Keep k small to reduce bus tail width and proof work. - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - // Keep the Shout bus lean: this program only needs ADD (for ADDI and effective address calculation). 
- let table_specs = add_only_table_specs(xlen); - let shout_table_ids: Vec = vec![3u32]; - - let mixers = default_mixers(); - - let run = |chunk_size: usize| -> (neo_memory::riscv::ccs::Rv32B1Layout, Vec>) { - let mut vm = RiscvCpu::new(xlen); - vm.load_program(0, program.clone()); - let twist = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, chunk_size).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs_base.n).expect("params"); - - let cpu = R1csCpu::new( - ccs_base, - params.clone(), - DummyCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .expect("cfg"), - chunk_size, - ) - .expect("shared bus inject"); - - let steps = build_shard_witness_shared_cpu_bus::<_, Cmt, K, _, _, _>( - vm, - twist, - shout, - max_steps, - chunk_size, - &mem_layouts, - &HashMap::new(), - &table_specs, - &HashMap::new(), - &initial_mem, - &cpu, - ) - .expect("build shard witness"); - let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); - - let mut tr_prove = Poseidon2Transcript::new(b"riscv-b1-chunk-eq"); - let proof = fold_shard_prove( - FoldingMode::Optimized, - &mut tr_prove, - ¶ms, - &cpu.ccs, - &steps, - &[], - &[], - &DummyCommit::default(), - mixers, - ) - .expect("prove"); - - let mut tr_verify = Poseidon2Transcript::new(b"riscv-b1-chunk-eq"); - let _ = fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_verify, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &steps_public, - &[], - &proof, - mixers, - &layout, - ) - .expect("verify"); - - // Tamper boundary chaining for chunk_size>1. 
- if chunk_size > 1 && steps_public.len() > 1 { - let mut bad_steps = steps_public.clone(); - bad_steps[1].mcs_inst.x[layout.pc0] += F::ONE; - let mut tr_bad = Poseidon2Transcript::new(b"riscv-b1-chunk-eq"); - assert!( - fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_bad, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &bad_steps, - &[], - &proof, - mixers, - &layout, - ) - .is_err(), - "expected step linking failure for chunk_size={chunk_size}" - ); - - // Also ensure the `halted_out -> halted_in` step linking is enforced (even when both are 0). - let mut bad_steps = steps_public.clone(); - bad_steps[1].mcs_inst.x[layout.halted_in] = F::ONE; - let mut tr_bad_halt = Poseidon2Transcript::new(b"riscv-b1-chunk-eq"); - assert!( - fold_shard_verify_rv32_b1_with_statement_mem_init( - FoldingMode::Optimized, - &mut tr_bad_halt, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &bad_steps, - &[], - &proof, - mixers, - &layout, - ) - .is_err(), - "expected halted_in/out step linking failure for chunk_size={chunk_size}" - ); - } - - (layout, steps_public) - }; - - let (layout_1, steps_1) = run(1); - let (layout_2, steps_2) = run(2); - - let start_1 = extract_boundary_state(&layout_1, &steps_1[0].mcs_inst.x).expect("boundary"); - let start_2 = extract_boundary_state(&layout_2, &steps_2[0].mcs_inst.x).expect("boundary"); - assert_eq!(start_1.pc0, start_2.pc0, "pc0 must be chunk-size invariant"); - assert_eq!(start_1.regs0, start_2.regs0, "regs0 must be chunk-size invariant"); - - let end_1 = extract_boundary_state(&layout_1, &steps_1.last().expect("non-empty").mcs_inst.x).expect("boundary"); - let end_2 = extract_boundary_state(&layout_2, &steps_2.last().expect("non-empty").mcs_inst.x).expect("boundary"); - assert_eq!(end_1.pc_final, end_2.pc_final, "pc_final must be chunk-size invariant"); - assert_eq!( - end_1.regs_final, end_2.regs_final, - "regs_final must be chunk-size invariant" - ); - - // Stronger equivalence: each chunk 
boundary in chunk_size=2 corresponds to the same boundary - // after the same number of steps in chunk_size=1. - let n = steps_1.len(); - let k = 2usize; - assert_eq!(n, max_steps, "chunk_size=1 should produce one chunk per step"); - assert_eq!(steps_2.len(), n.div_ceil(k), "unexpected chunk count for chunk_size=2"); - - for c in 0..steps_2.len() { - let s = c * k; - let e = ((c + 1) * k).min(n) - 1; - let st_k = extract_boundary_state(&layout_2, &steps_2[c].mcs_inst.x).expect("boundary"); - let st_1s = extract_boundary_state(&layout_1, &steps_1[s].mcs_inst.x).expect("boundary"); - let st_1e = extract_boundary_state(&layout_1, &steps_1[e].mcs_inst.x).expect("boundary"); - - assert_eq!(st_k.pc0, st_1s.pc0, "pc0 mismatch at chunk {c}"); - assert_eq!(st_k.regs0, st_1s.regs0, "regs0 mismatch at chunk {c}"); - assert_eq!(st_k.halted_in, st_1s.halted_in, "halted_in mismatch at chunk {c}"); - - assert_eq!(st_k.pc_final, st_1e.pc_final, "pc_final mismatch at chunk {c}"); - assert_eq!(st_k.regs_final, st_1e.regs_final, "regs_final mismatch at chunk {c}"); - assert_eq!(st_k.halted_out, st_1e.halted_out, "halted_out mismatch at chunk {c}"); - } + let mut run_1 = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(xlen) + .chunk_rows(1) + .max_steps(max_steps) + .shout_ops([RiscvOpcode::Add]) + .prove() + .expect("prove chunk_rows=1"); + run_1.verify().expect("verify chunk_rows=1"); + + let mut run_2 = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(xlen) + .chunk_rows(2) + .max_steps(max_steps) + .shout_ops([RiscvOpcode::Add]) + .prove() + .expect("prove chunk_rows=2"); + run_2.verify().expect("verify chunk_rows=2"); + + let first_1 = run_1.exec_table().rows.first().expect("non-empty trace"); + let first_2 = run_2.exec_table().rows.first().expect("non-empty trace"); + assert_eq!(first_1.pc_before, first_2.pc_before, "pc_before must be invariant"); + + let last_1 = run_1.exec_table().rows.last().expect("non-empty trace"); + let last_2 
= run_2.exec_table().rows.last().expect("non-empty trace"); + assert_eq!(last_1.pc_after, last_2.pc_after, "pc_after must be invariant"); + assert_eq!(run_1.trace_len(), run_2.trace_len(), "trace length must match"); } #[test] @@ -901,133 +291,50 @@ fn test_riscv_program_rv32m_full_prove_verify() { imm: 3, }, // x2 = 3 RiscvInstruction::RAlu { - op: RiscvOpcode::Mul, + op: RiscvOpcode::Slt, rd: 3, rs1: 1, rs2: 2, - }, + }, // x3 = 1 (signed compare) RiscvInstruction::RAlu { - op: RiscvOpcode::Div, + op: RiscvOpcode::Sltu, rd: 4, rs1: 1, rs2: 2, - }, + }, // x4 = 0 (unsigned compare) RiscvInstruction::Halt, ]; let max_steps = program.len(); - let program_bytes = encode_program(&program); - let mut vm = RiscvCpu::new(xlen); - vm.load_program(0, program); - let twist = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - // Minimal table set: ADD (for ADD/ADDI) + SLTU (for signed DIV/REM remainder-bound check when divisor != 0). 
- let shout_table_ids: Vec = vec![3, 6]; - let table_specs = HashMap::from([ - ( - 3u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Add, - xlen, - }, - ), - ( - 6u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Sltu, - xlen, - }, - ), - ]); - - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs_base.n).expect("params"); - - let cpu = R1csCpu::new( - ccs_base, - params.clone(), - DummyCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .expect("cfg"), - 1, - ) - .expect("shared bus inject"); - - let steps = build_shard_witness_shared_cpu_bus::<_, Cmt, K, _, _, _>( - vm, - twist, - shout, - /*max_steps=*/ max_steps, - /*chunk_size=*/ 1, - &mem_layouts, - &HashMap::new(), - &table_specs, - &HashMap::new(), - &initial_mem, - &cpu, - ) - .expect("build shard witness"); - let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(xlen) + .chunk_rows(1) + .min_trace_len(max_steps) + .max_steps(max_steps) + .shout_ops([RiscvOpcode::Add, RiscvOpcode::Slt, RiscvOpcode::Sltu]) + .reg_output_claim(/*reg=*/ 3, F::from_u64(1)) + .reg_output_claim(/*reg=*/ 4, F::from_u64(0)) + .prove() + .expect("prove"); - let mixers = default_mixers(); - let mut tr_prove = Poseidon2Transcript::new(b"riscv-b1-rv32m-full"); - let proof = fold_shard_prove( - FoldingMode::Optimized, - &mut tr_prove, - ¶ms, - &cpu.ccs, - &steps, - &[], - &[], - &DummyCommit::default(), - mixers, - ) - .expect("prove"); + run.verify().expect("verify"); - let mut tr_verify = Poseidon2Transcript::new(b"riscv-b1-rv32m-full"); - let _ = fold_shard_verify_rv32_b1_with_statement_mem_init( - 
FoldingMode::Optimized, - &mut tr_verify, - ¶ms, - &cpu.ccs, - &mem_layouts, - &initial_mem, - &steps_public, - &[], - &proof, - mixers, - &layout, - ) - .expect("verify"); + let compare_rows: Vec = run + .exec_table() + .rows + .iter() + .enumerate() + .filter_map(|(idx, row)| { + matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Slt | RiscvOpcode::Sltu, + .. + }) + ) + .then_some(idx) + }) + .collect(); + assert_eq!(compare_rows, vec![2, 3], "expected compare rows on SLT/SLTU steps"); } diff --git a/crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_full_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_full_prove_verify.rs index 6deef57b..da28747b 100644 --- a/crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_full_prove_verify.rs +++ b/crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_full_prove_verify.rs @@ -1,4 +1,4 @@ -//! End-to-end prove+verify for a small *compiled* RV32 guest program under the B1 shared-bus step circuit. +//! End-to-end prove+verify for a small *compiled* RV32 guest program under the trace wiring circuit. //! //! The guest is authored in Rust under `riscv-tests/guests/rv32-u64-output/` and its ROM bytes are committed in //! `riscv-tests/binaries/rv32_u64_output_rom.rs` so this test doesn't need to cross-compile at runtime. 
@@ -8,9 +8,8 @@ #[path = "binaries/rv32_u64_output_rom.rs"] mod rv32_u64_output_rom; -use neo_fold::riscv_shard::Rv32B1; +use neo_fold::riscv_trace_shard::Rv32TraceWiring; use neo_math::F; -use neo_memory::output_check::ProgramIO; use p3_field::PrimeCharacteristicRing; #[test] @@ -22,10 +21,9 @@ fn test_riscv_u64_output_compiled_full_prove_verify() { let program_base = rv32_u64_output_rom::RV32_U64_OUTPUT_ROM_BASE; let program_bytes: &[u8] = &rv32_u64_output_rom::RV32_U64_OUTPUT_ROM; - let mut run = Rv32B1::from_rom(program_base, program_bytes) + let mut run = Rv32TraceWiring::from_rom(program_base, program_bytes) .xlen(32) - .ram_bytes(0x200) - .chunk_size(4) + .chunk_rows(4) .shout_auto_minimal() .output_claim(/*addr=*/ 0x100, /*value=*/ out_lo) .output_claim(/*addr=*/ 0x104, /*value=*/ out_hi) @@ -36,11 +34,15 @@ fn test_riscv_u64_output_compiled_full_prove_verify() { run.verify().expect("verify"); println!("Verify duration: {:?}", run.verify_duration().expect("verify duration")); - let wrong = ProgramIO::new() - .with_output(0x100, out_lo) - .with_output(0x104, F::from_u64(0)); - assert!( - matches!(run.verify_output_claims(wrong), Ok(false) | Err(_)), - "wrong output claims must not verify" - ); + match Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .chunk_rows(4) + .shout_auto_minimal() + .output_claim(/*addr=*/ 0x100, /*value=*/ out_lo) + .output_claim(/*addr=*/ 0x104, /*value=*/ F::from_u64(0)) + .prove() + { + Ok(mut bad_run) => assert!(bad_run.verify().is_err(), "wrong output claims must fail verification"), + Err(_) => {} + } } diff --git a/crates/neo-fold/src/lib.rs b/crates/neo-fold/src/lib.rs index dae0b088..0f29f39b 100644 --- a/crates/neo-fold/src/lib.rs +++ b/crates/neo-fold/src/lib.rs @@ -26,8 +26,7 @@ pub mod finalize; // Shard-level folding (CPU + Memory Sidecar) pub mod shard; -// Convenience wrappers for RV32 shard verification -pub mod riscv_shard; +pub mod riscv_trace_shard; // Output binding integration pub mod 
output_binding; diff --git a/crates/neo-fold/src/memory_sidecar/claim_plan.rs b/crates/neo-fold/src/memory_sidecar/claim_plan.rs index b83f3aed..b2a9abb0 100644 --- a/crates/neo-fold/src/memory_sidecar/claim_plan.rs +++ b/crates/neo-fold/src/memory_sidecar/claim_plan.rs @@ -1,6 +1,7 @@ use neo_ajtai::Commitment as Cmt; use neo_math::{F, K}; -use neo_memory::witness::{LutInstance, MemInstance, StepInstanceBundle}; +use neo_memory::riscv::lookups::RiscvOpcode; +use neo_memory::witness::{LutInstance, LutTableSpec, MemInstance, StepInstanceBundle}; use crate::PiCcsError; @@ -13,8 +14,10 @@ pub struct TimeClaimMeta { #[derive(Clone, Debug)] pub struct ShoutLaneTimeClaimIdx { - pub value: usize, - pub adapter: usize, + pub value: Option, + pub adapter: Option, + pub event_table_hash: Option, + pub gamma_group: Option, } #[derive(Clone, Debug)] @@ -24,6 +27,29 @@ pub struct ShoutTimeClaimIdx { pub ell_addr: usize, } +#[derive(Clone, Debug)] +pub struct ShoutGammaGroupLaneRef { + pub flat_lane_idx: usize, + pub inst_idx: usize, + pub lane_idx: usize, +} + +#[derive(Clone, Debug)] +pub struct ShoutGammaGroupSpec { + pub key: u64, + pub ell_addr: usize, + pub lanes: Vec, +} + +#[derive(Clone, Debug)] +pub struct ShoutGammaGroupTimeClaimIdx { + pub key: u64, + pub ell_addr: usize, + pub lanes: Vec, + pub value: usize, + pub adapter: usize, +} + #[derive(Clone, Debug)] pub struct TwistTimeClaimIdx { pub read_check: usize, @@ -41,21 +67,111 @@ pub struct RouteATimeClaimPlan { pub claim_idx_start: usize, pub claim_idx_end: usize, pub shout: Vec, + pub shout_gamma_groups: Vec, + pub shout_event_trace_hash: Option, pub twist: Vec, + pub wb_bool: Option, + pub wp_quiescence: Option, + pub decode_fields: Option, + pub decode_immediates: Option, + pub width_bitness: Option, + pub width_quiescence: Option, + pub width_selector_linkage: Option, + pub width_load_semantics: Option, + pub width_store_semantics: Option, + pub control_next_pc_linear: Option, + pub control_next_pc_control: 
Option, + pub control_branch_semantics: Option, + pub control_writeback: Option, } impl RouteATimeClaimPlan { + pub fn derive_shout_gamma_groups_for_instances<'a, LI>(lut_insts: LI) -> Vec + where + LI: IntoIterator>, + { + let lut_insts: Vec<&LutInstance> = lut_insts.into_iter().collect(); + + // Group all non-packed lookup families that share an address group. + // The addr_group is carried on each LutInstance (set by the bus config for trace mode, + // None when no lookup-family sharing is configured). This collapses per-column decode/width families into one + // gamma-batched claim pair while keeping packed/event-table specs on their existing + // per-lane schedule. + let mut grouped: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + let mut grouped_ell: std::collections::BTreeMap = std::collections::BTreeMap::new(); + + let mut flat_lane_idx = 0usize; + for (inst_idx, lut_inst) in lut_insts.iter().enumerate() { + let lanes = lut_inst.lanes.max(1); + let ell_addr = lut_inst.d * lut_inst.ell; + let is_packed = matches!( + lut_inst.table_spec, + Some(LutTableSpec::RiscvOpcodePacked { .. } | LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}) + ); + let is_gamma_candidate = !is_packed && lut_inst.addr_group.is_some(); + for lane_idx in 0..lanes { + if is_gamma_candidate { + if let Some(addr_group) = lut_inst.addr_group { + let key = (addr_group << 32) | lane_idx as u64; + grouped + .entry(key) + .or_default() + .push(ShoutGammaGroupLaneRef { + flat_lane_idx, + inst_idx, + lane_idx, + }); + grouped_ell.entry(key).or_insert(ell_addr); + } + } + flat_lane_idx += 1; + } + } + + let mut out = Vec::new(); + for (key, lanes) in grouped.into_iter() { + if lanes.len() <= 1 { + continue; + } + if let Some(&ell_addr) = grouped_ell.get(&key) { + out.push(ShoutGammaGroupSpec { key, ell_addr, lanes }); + } + } + out + } + pub fn time_claim_metas_for_instances<'a, LI, MI>( lut_insts: LI, mem_insts: MI, ccs_time_degree_bound: usize, + wb_enabled: bool, + wp_enabled: bool, + decode_stage_enabled: bool, + width_stage_enabled: bool, + control_stage_enabled: bool, ob_inc_total_degree_bound: Option, ) -> Vec where LI: IntoIterator>, MI: IntoIterator>, { + let lut_insts: Vec<&LutInstance> = lut_insts.into_iter().collect(); + let mem_insts: Vec<&MemInstance> = mem_insts.into_iter().collect(); + let shout_gamma_groups = Self::derive_shout_gamma_groups_for_instances(lut_insts.iter().copied()); + let mut lane_gamma_map: std::collections::HashMap<(usize, usize), usize> = std::collections::HashMap::new(); + for (g_idx, g) in shout_gamma_groups.iter().enumerate() { + for lane in g.lanes.iter() { + lane_gamma_map.insert((lane.inst_idx, lane.lane_idx), g_idx); + } + } + let any_event_table_shout = lut_insts + .iter() + .any(|inst| matches!(inst.table_spec, Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}))); + let mut out = Vec::new(); + let mut gamma_value_degree_bounds = vec![0usize; shout_gamma_groups.len()]; + let mut gamma_adapter_degree_bounds = vec![0usize; shout_gamma_groups.len()]; out.push(TimeClaimMeta { label: b"ccs/time", @@ -63,21 +179,59 @@ impl RouteATimeClaimPlan { is_dynamic: true, }); - for lut_inst in lut_insts { + for (inst_idx, lut_inst) in lut_insts.iter().enumerate() { let ell_addr = lut_inst.d * lut_inst.ell; let lanes = lut_inst.lanes.max(1); - - for _lane in 0..lanes { - out.push(TimeClaimMeta { - label: b"shout/value", - degree_bound: 3, - is_dynamic: true, - }); - out.push(TimeClaimMeta { - label: b"shout/adapter", - degree_bound: 2 + ell_addr, - is_dynamic: true, - }); + let (packed_opcode, _packed_base_ell_addr) = match &lut_inst.table_spec { + Some(LutTableSpec::RiscvOpcodePacked { opcode, xlen: 32 }) => (Some(*opcode), ell_addr), + Some(LutTableSpec::RiscvOpcodeEventTablePacked { + opcode, + xlen: 32, + time_bits, + }) => (Some(*opcode), ell_addr.saturating_sub(*time_bits)), + _ => (None, ell_addr), + }; + + let (value_degree_bound, adapter_degree_bound) = match packed_opcode { + Some(RiscvOpcode::And | RiscvOpcode::Andn | RiscvOpcode::Or | RiscvOpcode::Xor) => (8, 6), + Some(RiscvOpcode::Add | RiscvOpcode::Sub) => (3, 2), + Some(RiscvOpcode::Eq | RiscvOpcode::Neq) => (34, 3), + Some(RiscvOpcode::Mul) => (4, 2), + Some(RiscvOpcode::Mulh) => (4, 5), + Some(RiscvOpcode::Mulhu) => (4, 2), + Some(RiscvOpcode::Mulhsu) => (4, 4), + Some(RiscvOpcode::Slt) => (3, 3), + Some(RiscvOpcode::Divu | RiscvOpcode::Remu) => (5, 4), + Some(RiscvOpcode::Div | RiscvOpcode::Rem) => (7, 6), + Some(RiscvOpcode::Sll) => (8, 2), + Some(RiscvOpcode::Srl | RiscvOpcode::Sra) => (8, 8), + Some(RiscvOpcode::Sltu) => (3, 3), + _ => (3, 2 + ell_addr), + }; + + for lane_idx in 0..lanes { + if let Some(&g_idx) = lane_gamma_map.get(&(inst_idx, lane_idx)) { + gamma_value_degree_bounds[g_idx] = gamma_value_degree_bounds[g_idx].max(value_degree_bound); + 
gamma_adapter_degree_bounds[g_idx] = gamma_adapter_degree_bounds[g_idx].max(adapter_degree_bound); + } else { + out.push(TimeClaimMeta { + label: b"shout/value", + degree_bound: value_degree_bound, + is_dynamic: true, + }); + out.push(TimeClaimMeta { + label: b"shout/adapter", + degree_bound: adapter_degree_bound, + is_dynamic: true, + }); + } + if let Some(LutTableSpec::RiscvOpcodeEventTablePacked { time_bits, .. }) = &lut_inst.table_spec { + out.push(TimeClaimMeta { + label: b"shout/event_table_hash", + degree_bound: 2 + *time_bits, + is_dynamic: true, + }); + } } out.push(TimeClaimMeta { @@ -87,6 +241,27 @@ impl RouteATimeClaimPlan { }); } + for (g_idx, _) in shout_gamma_groups.iter().enumerate() { + out.push(TimeClaimMeta { + label: b"shout/value", + degree_bound: gamma_value_degree_bounds[g_idx], + is_dynamic: true, + }); + out.push(TimeClaimMeta { + label: b"shout/adapter", + degree_bound: gamma_adapter_degree_bounds[g_idx], + is_dynamic: true, + }); + } + + if any_event_table_shout { + out.push(TimeClaimMeta { + label: b"shout/event_trace_hash", + degree_bound: 3, + is_dynamic: true, + }); + } + for mem_inst in mem_insts { let ell_addr = mem_inst.d * mem_inst.ell; @@ -108,6 +283,81 @@ impl RouteATimeClaimPlan { }); } + if wb_enabled { + out.push(TimeClaimMeta { + label: b"wb/booleanity", + degree_bound: 3, + is_dynamic: false, + }); + } + + if wp_enabled { + out.push(TimeClaimMeta { + label: b"wp/quiescence", + degree_bound: 3, + is_dynamic: false, + }); + } + + if decode_stage_enabled { + out.push(TimeClaimMeta { + label: b"decode/fields", + degree_bound: 4, + is_dynamic: false, + }); + out.push(TimeClaimMeta { + label: b"decode/immediates", + degree_bound: 3, + is_dynamic: false, + }); + } + + if width_stage_enabled { + out.push(TimeClaimMeta { + label: b"width/bitness", + degree_bound: 3, + is_dynamic: false, + }); + out.push(TimeClaimMeta { + label: b"width/quiescence", + degree_bound: 3, + is_dynamic: false, + }); + out.push(TimeClaimMeta { + label: 
b"width/load_semantics", + degree_bound: 4, + is_dynamic: false, + }); + out.push(TimeClaimMeta { + label: b"width/store_semantics", + degree_bound: 4, + is_dynamic: false, + }); + } + + if control_stage_enabled { + out.push(TimeClaimMeta { + label: b"control/next_pc_linear", + degree_bound: 3, + is_dynamic: false, + }); + out.push(TimeClaimMeta { + label: b"control/next_pc_control", + degree_bound: 5, + is_dynamic: false, + }); + out.push(TimeClaimMeta { + label: b"control/branch_semantics", + degree_bound: 4, + is_dynamic: false, + }); + out.push(TimeClaimMeta { + label: b"control/writeback", + degree_bound: 4, + is_dynamic: false, + }); + } + if let Some(degree_bound) = ob_inc_total_degree_bound { out.push(TimeClaimMeta { label: crate::output_binding::OB_INC_TOTAL_LABEL, @@ -127,12 +377,22 @@ impl RouteATimeClaimPlan { pub fn time_claim_metas_for_step( step: &StepInstanceBundle, ccs_time_degree_bound: usize, + wb_enabled: bool, + wp_enabled: bool, + decode_stage_enabled: bool, + width_stage_enabled: bool, + control_stage_enabled: bool, ob_inc_total_degree_bound: Option, ) -> Vec { Self::time_claim_metas_for_instances( step.lut_insts.iter(), step.mem_insts.iter(), ccs_time_degree_bound, + wb_enabled, + wp_enabled, + decode_stage_enabled, + width_stage_enabled, + control_stage_enabled, ob_inc_total_degree_bound, ) } @@ -140,21 +400,59 @@ impl RouteATimeClaimPlan { pub fn build( step: &StepInstanceBundle, claim_idx_start: usize, + wb_enabled: bool, + wp_enabled: bool, + decode_stage_enabled: bool, + width_stage_enabled: bool, + control_stage_enabled: bool, ) -> Result { let mut idx = claim_idx_start; let mut shout = Vec::with_capacity(step.lut_insts.len()); + let shout_gamma_specs = Self::derive_shout_gamma_groups_for_instances(step.lut_insts.iter()); + let mut lane_gamma_map: std::collections::HashMap<(usize, usize), usize> = std::collections::HashMap::new(); + for (g_idx, g) in shout_gamma_specs.iter().enumerate() { + for lane in g.lanes.iter() { + 
lane_gamma_map.insert((lane.inst_idx, lane.lane_idx), g_idx); + } + } + let any_event_table_shout = step + .lut_insts + .iter() + .any(|inst| matches!(inst.table_spec, Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. }))); let mut twist = Vec::with_capacity(step.mem_insts.len()); - for lut_inst in &step.lut_insts { + for (inst_idx, lut_inst) in step.lut_insts.iter().enumerate() { let ell_addr = lut_inst.d * lut_inst.ell; let lanes = lut_inst.lanes.max(1); + let is_event_table = matches!( + lut_inst.table_spec, + Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. }) + ); let mut lane_claims: Vec = Vec::with_capacity(lanes); - for _lane in 0..lanes { - let value = idx; - idx += 1; - let adapter = idx; - idx += 1; - lane_claims.push(ShoutLaneTimeClaimIdx { value, adapter }); + for lane_idx in 0..lanes { + let gamma_group = lane_gamma_map.get(&(inst_idx, lane_idx)).copied(); + let (value, adapter) = if gamma_group.is_some() { + (None, None) + } else { + let value = idx; + idx += 1; + let adapter = idx; + idx += 1; + (Some(value), Some(adapter)) + }; + let event_table_hash = if is_event_table { + let h = idx; + idx += 1; + Some(h) + } else { + None + }; + lane_claims.push(ShoutLaneTimeClaimIdx { + value, + adapter, + event_table_hash, + gamma_group, + }); } let bitness = idx; idx += 1; @@ -166,6 +464,29 @@ impl RouteATimeClaimPlan { }); } + let mut shout_gamma_groups = Vec::with_capacity(shout_gamma_specs.len()); + for spec in shout_gamma_specs.into_iter() { + let value = idx; + idx += 1; + let adapter = idx; + idx += 1; + shout_gamma_groups.push(ShoutGammaGroupTimeClaimIdx { + key: spec.key, + ell_addr: spec.ell_addr, + lanes: spec.lanes, + value, + adapter, + }); + } + + let shout_event_trace_hash = if any_event_table_shout { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + for mem_inst in &step.mem_insts { let ell_addr = mem_inst.d * mem_inst.ell; let read_check = idx; @@ -184,6 +505,104 @@ impl RouteATimeClaimPlan { }); } + let wb_bool = if 
wb_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let wp_quiescence = if wp_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let decode_fields = if decode_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let decode_immediates = if decode_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let width_bitness = if width_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let width_quiescence = if width_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let width_selector_linkage = None; + + let width_load_semantics = if width_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let width_store_semantics = if width_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let control_next_pc_linear = if control_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let control_next_pc_control = if control_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let control_branch_semantics = if control_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + + let control_writeback = if control_stage_enabled { + let out = idx; + idx += 1; + Some(out) + } else { + None + }; + if idx < claim_idx_start { return Err(PiCcsError::ProtocolError("RouteATimeClaimPlan index underflow".into())); } @@ -192,7 +611,22 @@ impl RouteATimeClaimPlan { claim_idx_start, claim_idx_end: idx, shout, + shout_gamma_groups, + shout_event_trace_hash, twist, + wb_bool, + wp_quiescence, + decode_fields, + decode_immediates, + width_bitness, + width_quiescence, + width_selector_linkage, + width_load_semantics, + width_store_semantics, + control_next_pc_linear, + control_next_pc_control, + control_branch_semantics, + control_writeback, }) } } diff --git 
a/crates/neo-fold/src/memory_sidecar/cpu_bus.rs b/crates/neo-fold/src/memory_sidecar/cpu_bus.rs index 4c2afdfe..756eb26a 100644 --- a/crates/neo-fold/src/memory_sidecar/cpu_bus.rs +++ b/crates/neo-fold/src/memory_sidecar/cpu_bus.rs @@ -2,12 +2,16 @@ use crate::PiCcsError; use neo_ccs::{CcsMatrix, CcsStructure, Mat, MeInstance}; use neo_math::{F, K}; use neo_memory::ajtai::decode_vector as ajtai_decode_vector; -use neo_memory::cpu::{build_bus_layout_for_instances_with_shout_and_twist_lanes, BusLayout}; +use neo_memory::cpu::{ + build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes, BusLayout, ShoutInstanceShape, +}; +use neo_memory::riscv::lookups::{PROG_ID, REG_ID}; +use neo_memory::riscv::trace::{rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id}; use neo_memory::sparse_time::SparseIdxVec; use neo_memory::witness::{LutInstance, MemInstance, StepInstanceBundle, StepWitnessBundle}; use neo_params::NeoParams; use p3_field::PrimeCharacteristicRing; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; pub(crate) trait BusStepView { fn m_in(&self) -> usize; @@ -105,7 +109,6 @@ fn infer_bus_layout_for_steps>( } let chunk_size = infer_chunk_size_from_steps(steps)?; - let base_shout_ell_addrs: Vec = (0..steps[0].lut_insts_len()) .map(|i| { let inst = steps[0].lut_inst(i); @@ -118,6 +121,12 @@ fn infer_bus_layout_for_steps>( inst.lanes }) .collect(); + let base_shout_addr_groups: Vec> = (0..steps[0].lut_insts_len()) + .map(|i| steps[0].lut_inst(i).addr_group) + .collect(); + let base_shout_selector_groups: Vec> = (0..steps[0].lut_insts_len()) + .map(|i| steps[0].lut_inst(i).selector_group) + .collect(); let base_twist_ell_addrs: Vec = (0..steps[0].mem_insts_len()) .map(|i| { let inst = steps[0].mem_inst(i); @@ -144,6 +153,12 @@ fn infer_bus_layout_for_steps>( inst.lanes }) .collect(); + let cur_shout_addr_groups: Vec> = (0..step.lut_insts_len()) + .map(|j| step.lut_inst(j).addr_group) + .collect(); + let 
cur_shout_selector_groups: Vec> = (0..step.lut_insts_len()) + .map(|j| step.lut_inst(j).selector_group) + .collect(); let cur_twist: Vec = (0..step.mem_insts_len()) .map(|j| { let inst = step.mem_inst(j); @@ -158,6 +173,8 @@ fn infer_bus_layout_for_steps>( .collect(); if cur_shout != base_shout_ell_addrs || cur_shout_lanes != base_shout_lanes + || cur_shout_addr_groups != base_shout_addr_groups + || cur_shout_selector_groups != base_shout_selector_groups || cur_twist != base_twist_ell_addrs || cur_twist_lanes != base_twist_lanes { @@ -167,21 +184,29 @@ fn infer_bus_layout_for_steps>( } } - let shout_ell_addrs_and_lanes = base_shout_ell_addrs + let shout_shapes = base_shout_ell_addrs .iter() .copied() .zip(base_shout_lanes.iter().copied()) - .map(|(ell_addr, lanes)| (ell_addr, lanes)); + .zip(base_shout_addr_groups.iter().copied()) + .zip(base_shout_selector_groups.iter().copied()) + .map(|(((ell_addr, lanes), addr_group), selector_group)| ShoutInstanceShape { + ell_addr, + lanes, + n_vals: 1usize, + addr_group, + selector_group, + }); let twist_ell_addrs_and_lanes = base_twist_ell_addrs .iter() .copied() .zip(base_twist_lanes.iter().copied()) .map(|(ell_addr, lanes)| (ell_addr, lanes)); - let layout = build_bus_layout_for_instances_with_shout_and_twist_lanes( + let layout = build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes( s.m, m_in, chunk_size, - shout_ell_addrs_and_lanes, + shout_shapes, twist_ell_addrs_and_lanes, ) .map_err(PiCcsError::InvalidInput)?; @@ -212,7 +237,7 @@ pub(crate) fn prepare_ccs_for_shared_cpu_bus_steps<'a, Cmt, S: BusStepView> ) -> Result<(&'a CcsStructure, BusLayout), PiCcsError> { let bus = infer_bus_layout_for_steps(s0, steps)?; let padding_rows = ensure_ccs_has_shared_bus_padding_for_steps(s0, &bus, steps)?; - ensure_ccs_binds_shared_bus_for_steps(s0, &bus, &padding_rows)?; + ensure_ccs_binds_shared_bus_for_steps(s0, &bus, &padding_rows, steps)?; // Performance: do NOT materialize bus copyout matrices into the CCS. 
Instead, we append the // corresponding ME openings directly from the witness (see `append_bus_openings_to_me_*`). Ok((s0, bus)) @@ -230,6 +255,41 @@ fn chi_for_row_index(r: &[K], idx: usize) -> K { acc } +#[inline] +fn precompute_contiguous_time_weights(r: &[K], start_row: usize, len: usize, n_pad: usize) -> Vec { + if len == 0 { + return Vec::new(); + } + + // For large contiguous windows, build χ_r over the full boolean domain once and slice. + // This avoids repeated per-index basis recomputation in hot Route-A paths. + const FULL_CHI_MAX_PAD: usize = 1 << 20; + let naive_ops = len.saturating_mul(r.len().max(1)); + let use_full_table = len >= 1024 && n_pad <= FULL_CHI_MAX_PAD && naive_ops >= n_pad; + + if use_full_table { + let mut chi = Vec::with_capacity(n_pad); + chi.push(K::ONE); + for &ri in r { + let one_minus_ri = K::ONE - ri; + let cur_len = chi.len(); + chi.reserve(cur_len); + for i in 0..cur_len { + let v = chi[i]; + chi[i] = v * one_minus_ri; + chi.push(v * ri); + } + } + return chi[start_row..start_row + len].to_vec(); + } + + let mut out = Vec::with_capacity(len); + for off in 0..len { + out.push(chi_for_row_index(r, start_row + off)); + } + out +} + pub(crate) fn append_bus_openings_to_me_instance( params: &NeoParams, bus: &BusLayout, @@ -292,7 +352,7 @@ where let want_len = core_t .checked_add(bus.bus_cols) .ok_or_else(|| PiCcsError::InvalidInput("core_t + bus_cols overflow".into()))?; - if me.y.len() == want_len && me.y_scalars.len() == want_len { + if me.y.len() >= want_len && me.y_scalars.len() >= want_len && me.y.len() == me.y_scalars.len() { return Ok(()); } if me.y.len() != core_t || me.y_scalars.len() != core_t { @@ -314,10 +374,12 @@ where } // Precompute χ_r(time_index(j)) weights for the bus time rows. 
- let mut time_weights = Vec::with_capacity(bus.chunk_size); - for j in 0..bus.chunk_size { - time_weights.push(chi_for_row_index(&me.r, bus.time_index(j))); - } + let time_weights = precompute_contiguous_time_weights(&me.r, bus.time_index(0), bus.chunk_size, n_pad); + let weighted_rows: Vec<(usize, K)> = time_weights + .into_iter() + .enumerate() + .filter_map(|(j, w)| (w != K::ZERO).then_some((j, w))) + .collect(); // Base-b powers for recomposition. let bK = K::from(F::from_u64(params.b as u64)); @@ -331,23 +393,346 @@ where // Append bus openings in canonical col_id order so `bus_y_base = y_scalars.len() - bus_cols` // remains valid. for col_id in 0..bus.bus_cols { + let col_base = bus + .bus_base + .checked_add( + col_id + .checked_mul(bus.chunk_size) + .ok_or_else(|| PiCcsError::InvalidInput("bus col_id * chunk_size overflow".into()))?, + ) + .ok_or_else(|| PiCcsError::InvalidInput("bus col_base overflow".into()))?; let mut y_row = vec![K::ZERO; y_pad]; + let mut y_scalar = K::ZERO; for rho in 0..d { let mut acc = K::ZERO; - for j in 0..bus.chunk_size { - let w = time_weights[j]; - if w == K::ZERO { - continue; - } - let z_idx = bus.bus_cell(col_id, j); - acc += w * K::from(Z[(rho, z_idx)]); + for &(j, w) in weighted_rows.iter() { + acc += w * K::from(Z[(rho, col_base + j)]); } y_row[rho] = acc; + y_scalar += acc * pow_b[rho]; + } + + me.y.push(y_row); + me.y_scalars.push(y_scalar); + } + + Ok(()) +} + +/// Append time-indexed openings for a column-major region of the CPU witness. +/// +/// This is a "no shared CPU bus tail" bridge: instead of materializing copyout matrices for +/// per-row columns (e.g. an execution trace), we compute their Route-A time-combined openings +/// directly from the committed witness matrix `Z` and append them to the ME instance. +/// +/// Semantics: for each `col_id` in `cols`, append an opening for the vector +/// `{ z[col_base + col_id * t_len + j] }_{j=0..t_len-1}` combined with weights +/// `χ_r(m_in + j)` where `r = me.r`. 
+pub(crate) fn append_col_major_time_openings_to_me_instance( + params: &NeoParams, + m_in: usize, + t_len: usize, + col_base: usize, + cols: &[usize], + core_t: usize, + Z: &Mat, + me: &mut MeInstance, +) -> Result<(), PiCcsError> +where + Cmt: Clone, +{ + if cols.is_empty() { + return Ok(()); + } + if t_len == 0 { + return Err(PiCcsError::InvalidInput("trace openings require t_len >= 1".into())); + } + + let y_pad = (params.d as usize).next_power_of_two(); + let d = neo_math::D; + if y_pad < d { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require y_pad >= D (y_pad={y_pad}, D={d})" + ))); + } + if Z.rows() != d { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require Z.rows()==D (got {}, want {})", + Z.rows(), + d + ))); + } + if me.m_in != m_in { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require ME.m_in==m_in (got {}, want {})", + me.m_in, m_in + ))); + } + if me.r.is_empty() { + return Err(PiCcsError::InvalidInput("trace openings require non-empty ME.r".into())); + } + if col_base >= Z.cols() { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require col_base < m (col_base={}, m={})", + col_base, + Z.cols() + ))); + } + + let n_pad = 1usize + .checked_shl(me.r.len() as u32) + .ok_or_else(|| PiCcsError::InvalidInput("2^ell_n overflow".into()))?; + for j in 0..t_len { + let row = m_in + .checked_add(j) + .ok_or_else(|| PiCcsError::InvalidInput("m_in + t overflow".into()))?; + if row >= n_pad { + return Err(PiCcsError::InvalidInput(format!( + "trace time row index out of range: (m_in + j)={} out of range for ell_n={} (n_pad={})", + row, + me.r.len(), + n_pad + ))); + } + } + + // Idempotent append: allow callers to call this once; reject unexpected shapes. 
+ let want_len = core_t + .checked_add(cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("core_t + cols.len overflow".into()))?; + if me.y.len() == want_len && me.y_scalars.len() == want_len { + return Ok(()); + } + if me.y.len() != core_t || me.y_scalars.len() != core_t { + return Err(PiCcsError::InvalidInput(format!( + "trace openings expect ME y/y_scalars to start at core_t (y.len()={}, y_scalars.len()={}, core_t={})", + me.y.len(), + me.y_scalars.len(), + core_t + ))); + } + for (j, row) in me.y.iter().enumerate() { + if row.len() != y_pad { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require ME.y[{j}].len()==y_pad (got {}, want {})", + row.len(), + y_pad + ))); + } + } + + // Precompute χ_r(m_in + j) weights for the time rows. + let time_weights = precompute_contiguous_time_weights(&me.r, m_in, t_len, n_pad); + let weighted_rows: Vec<(usize, K)> = time_weights + .into_iter() + .enumerate() + .filter_map(|(j, w)| (w != K::ZERO).then_some((j, w))) + .collect(); + + // Base-b powers for recomposition. 
+ let bK = K::from(F::from_u64(params.b as u64)); + let mut pow_b = Vec::with_capacity(d); + let mut cur = K::ONE; + for _ in 0..d { + pow_b.push(cur); + cur *= bK; + } + + for &col_id in cols { + let col_offset = col_id + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_id * t_len overflow".into()))?; + let col_start = col_base + .checked_add(col_offset) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_base + col_offset overflow".into()))?; + let col_end = col_start + .checked_add(t_len - 1) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_end overflow".into()))?; + if col_end >= Z.cols() { + return Err(PiCcsError::InvalidInput(format!( + "trace openings: column span out of range (col_start={col_start}, col_end={col_end}, m={})", + Z.cols() + ))); } + let mut y_row = vec![K::ZERO; y_pad]; let mut y_scalar = K::ZERO; for rho in 0..d { - y_scalar += y_row[rho] * pow_b[rho]; + let mut acc = K::ZERO; + for (j, w) in weighted_rows.iter() { + acc += *w * K::from(Z[(rho, col_start + *j)]); + } + y_row[rho] = acc; + y_scalar += acc * pow_b[rho]; + } + + me.y.push(y_row); + me.y_scalars.push(y_scalar); + } + + Ok(()) +} + +/// Append time-indexed openings for a column-major region of the CPU witness, using only the +/// selected time rows `js`. +/// +/// This is valid when the caller knows that for each opened column `col_id`, all omitted rows +/// (`j` not in `js`) are zero in the witness; then the opening can be computed by summing only +/// over `js`. 
+pub(crate) fn append_col_major_time_openings_to_me_instance_at_js( + params: &NeoParams, + m_in: usize, + t_len: usize, + col_base: usize, + cols: &[usize], + core_t: usize, + Z: &Mat, + me: &mut MeInstance, + js: &[usize], +) -> Result<(), PiCcsError> +where + Cmt: Clone, +{ + if cols.is_empty() { + return Ok(()); + } + if t_len == 0 { + return Err(PiCcsError::InvalidInput("trace openings require t_len >= 1".into())); + } + + let y_pad = (params.d as usize).next_power_of_two(); + let d = neo_math::D; + if y_pad < d { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require y_pad >= D (y_pad={y_pad}, D={d})" + ))); + } + if Z.rows() != d { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require Z.rows()==D (got {}, want {})", + Z.rows(), + d + ))); + } + if me.m_in != m_in { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require ME.m_in==m_in (got {}, want {})", + me.m_in, m_in + ))); + } + if me.r.is_empty() { + return Err(PiCcsError::InvalidInput("trace openings require non-empty ME.r".into())); + } + if col_base >= Z.cols() { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require col_base < m (col_base={}, m={})", + col_base, + Z.cols() + ))); + } + + let n_pad = 1usize + .checked_shl(me.r.len() as u32) + .ok_or_else(|| PiCcsError::InvalidInput("2^ell_n overflow".into()))?; + for &j in js { + if j >= t_len { + return Err(PiCcsError::InvalidInput(format!( + "trace js out of range: j={j} >= t_len={t_len}" + ))); + } + let row = m_in + .checked_add(j) + .ok_or_else(|| PiCcsError::InvalidInput("m_in + j overflow".into()))?; + if row >= n_pad { + return Err(PiCcsError::InvalidInput(format!( + "trace time row index out of range: (m_in + j)={row} out of range for ell_n={} (n_pad={})", + me.r.len(), + n_pad + ))); + } + } + + // Idempotent append: allow callers to call this once; reject unexpected shapes. 
+ let want_len = core_t + .checked_add(cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("core_t + cols.len overflow".into()))?; + if me.y.len() == want_len && me.y_scalars.len() == want_len { + return Ok(()); + } + if me.y.len() != core_t || me.y_scalars.len() != core_t { + return Err(PiCcsError::InvalidInput(format!( + "trace openings expect ME y/y_scalars to start at core_t (y.len()={}, y_scalars.len()={}, core_t={})", + me.y.len(), + me.y_scalars.len(), + core_t + ))); + } + for (j, row) in me.y.iter().enumerate() { + if row.len() != y_pad { + return Err(PiCcsError::InvalidInput(format!( + "trace openings require ME.y[{j}].len()==y_pad (got {}, want {})", + row.len(), + y_pad + ))); + } + } + + // Precompute χ_r(m_in + j) weights for the selected time rows. + let dense_selection = js.len().saturating_mul(3) >= t_len; + let time_weights: Vec<(usize, K)> = if dense_selection { + let all = precompute_contiguous_time_weights(&me.r, m_in, t_len, n_pad); + let mut out = Vec::with_capacity(js.len()); + for &j in js { + out.push((j, all[j])); + } + out + } else { + let mut out = Vec::with_capacity(js.len()); + for &j in js { + out.push((j, chi_for_row_index(&me.r, m_in + j))); + } + out + }; + let weighted_rows: Vec<(usize, K)> = time_weights + .into_iter() + .filter_map(|(j, w)| (w != K::ZERO).then_some((j, w))) + .collect(); + + // Base-b powers for recomposition. 
+ let bK = K::from(F::from_u64(params.b as u64)); + let mut pow_b = Vec::with_capacity(d); + let mut cur = K::ONE; + for _ in 0..d { + pow_b.push(cur); + cur *= bK; + } + + for &col_id in cols { + let col_offset = col_id + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_id * t_len overflow".into()))?; + let col_start = col_base + .checked_add(col_offset) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_base + col_offset overflow".into()))?; + let col_end = col_start + .checked_add(t_len - 1) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_end overflow".into()))?; + if col_end >= Z.cols() { + return Err(PiCcsError::InvalidInput(format!( + "trace openings: column span out of range (col_start={col_start}, col_end={col_end}, m={})", + Z.cols() + ))); + } + + let mut y_row = vec![K::ZERO; y_pad]; + let mut y_scalar = K::ZERO; + for rho in 0..d { + let mut acc = K::ZERO; + for (j, w) in weighted_rows.iter() { + acc += *w * K::from(Z[(rho, col_start + *j)]); + } + y_row[rho] = acc; + y_scalar += acc * pow_b[rho]; } me.y.push(y_row); @@ -424,7 +809,7 @@ fn required_bus_cols_for_layout(layout: &BusLayout) -> Vec { label: format!("shout[{lut_idx}].lane[{lane_idx}].has_lookup"), }); out.push(BusColLabel { - col_id: shout.val, + col_id: shout.primary_val(), label: format!("shout[{lut_idx}].lane[{lane_idx}].val"), }); } @@ -470,7 +855,7 @@ fn required_bus_cols_for_layout(layout: &BusLayout) -> Vec { out } -fn required_bus_binding_cols_for_layout(layout: &BusLayout) -> Vec { +fn required_bus_binding_cols_for_layout>(layout: &BusLayout, steps: &[S]) -> Vec { // Note: `inc_at_write_addr` is a Twist-internal witness field derived from the sparse // memory state. Many CPU CCSes do not (and should not) constrain it outside padding rows; // it is constrained by the Twist Route-A checks instead. 
We still require the canonical @@ -480,9 +865,82 @@ fn required_bus_binding_cols_for_layout(layout: &BusLayout) -> Vec .iter() .flat_map(|inst| inst.lanes.iter().map(|t| t.inc)) .collect(); + + // Shout key `addr_bits` are often constrained outside the *main* CPU CCS: + // - by a decode/semantics sidecar CCS, and/or + // - by VM-specific constraints that live outside the shared-bus binding gadget. + // + // The Route-A Shout argument already constrains `(addr_bits, val)` internally via: + // - per-lane Shout value/adaptor terminal checks, and + // - trace linkage checks (`verify_route_a_memory_step`) that bind the + // CPU trace's `(shout_has_lookup, shout_val, shout_lhs, shout_rhs)` to the sidecar openings. + // + // In RV32 trace shared-bus mode, Shout table-linkage ownership is moved to reduction-time + // aggregate checks, so the shared-bus adapter may intentionally omit CPU-linkage equalities + // for Shout lanes. Keep only canonical Shout padding/bitness constraints in the CPU CCS and + // exempt all Shout columns from the "outside-padding binding" guard. 
+ let shout_addr_cols: HashSet = layout + .shout_cols + .iter() + .flat_map(|inst| inst.lanes.iter().flat_map(|s| s.addr_bits.clone())) + .collect(); + let shout_selector_and_val_cols: HashSet = layout + .shout_cols + .iter() + .flat_map(|inst| { + inst.lanes + .iter() + .flat_map(|s| [s.has_lookup, s.primary_val()]) + }) + .collect(); + + let mut twist_unbound_cols: HashSet = HashSet::new(); + if let Some(step0) = steps.first() { + let has_trace_lookup_families = (0..step0.lut_insts_len()).any(|idx| { + let table_id = step0.lut_inst(idx).table_id; + rv32_is_decode_lookup_table_id(table_id) || rv32_is_width_lookup_table_id(table_id) + }); + if has_trace_lookup_families { + for (mem_idx, inst) in layout.twist_cols.iter().enumerate() { + let mem_id = step0.mem_inst(mem_idx).mem_id; + for (lane_idx, twist) in inst.lanes.iter().enumerate() { + let read_bound = if mem_id == PROG_ID.0 { + lane_idx == 0 + } else if mem_id == REG_ID.0 { + lane_idx <= 1 + } else { + lane_idx == 0 + }; + let write_bound = if mem_id == REG_ID.0 { + lane_idx == 0 + } else if mem_id == PROG_ID.0 { + false + } else { + lane_idx == 0 + }; + + if !read_bound { + twist_unbound_cols.insert(twist.has_read); + twist_unbound_cols.insert(twist.rv); + twist_unbound_cols.extend(twist.ra_bits.clone()); + } + if !write_bound { + twist_unbound_cols.insert(twist.has_write); + twist_unbound_cols.insert(twist.wv); + twist_unbound_cols.insert(twist.inc); + twist_unbound_cols.extend(twist.wa_bits.clone()); + } + } + } + } + } + required_bus_cols_for_layout(layout) .into_iter() .filter(|c| !inc_cols.contains(&c.col_id)) + .filter(|c| !shout_addr_cols.contains(&c.col_id)) + .filter(|c| !shout_selector_and_val_cols.contains(&c.col_id)) + .filter(|c| !twist_unbound_cols.contains(&c.col_id)) .collect() } @@ -653,12 +1111,21 @@ struct BusPaddingLabel { fn required_bus_padding_for_layout(bus: &BusLayout) -> Vec { let mut out = Vec::::new(); + let mut shout_addr_range_counts = std::collections::HashMap::<(usize, usize), 
usize>::new(); + for inst in bus.shout_cols.iter() { + for shout in inst.lanes.iter() { + let key = (shout.addr_bits.start, shout.addr_bits.end); + *shout_addr_range_counts.entry(key).or_insert(0) += 1; + } + } for (lut_idx, inst) in bus.shout_cols.iter().enumerate() { for (lane_idx, shout) in inst.lanes.iter().enumerate() { + let key = (shout.addr_bits.start, shout.addr_bits.end); + let shared_addr_group = shout_addr_range_counts.get(&key).copied().unwrap_or(0) > 1; for j in 0..bus.chunk_size { let has_lookup_z = bus.bus_cell(shout.has_lookup, j); - let val_z = bus.bus_cell(shout.val, j); + let val_z = bus.bus_cell(shout.primary_val(), j); // (1 - has_lookup) * val = 0 out.push(BusPaddingLabel { @@ -667,14 +1134,16 @@ fn required_bus_padding_for_layout(bus: &BusLayout) -> Vec { label: format!("shout[{lut_idx}].lane[{lane_idx}][j={j}]: (1-has_lookup)*val"), }); - // (1 - has_lookup) * addr_bits[b] = 0 - for (b, col_id) in shout.addr_bits.clone().enumerate() { - let bit_z = bus.bus_cell(col_id, j); - out.push(BusPaddingLabel { - flag_z_idx: has_lookup_z, - field_z_idx: bit_z, - label: format!("shout[{lut_idx}].lane[{lane_idx}][j={j}]: (1-has_lookup)*addr_bits[{b}]"), - }); + if !shared_addr_group { + // (1 - has_lookup) * addr_bits[b] = 0 + for (b, col_id) in shout.addr_bits.clone().enumerate() { + let bit_z = bus.bus_cell(col_id, j); + out.push(BusPaddingLabel { + flag_z_idx: has_lookup_z, + field_z_idx: bit_z, + label: format!("shout[{lut_idx}].lane[{lane_idx}][j={j}]: (1-has_lookup)*addr_bits[{b}]"), + }); + } } } } @@ -792,9 +1261,19 @@ fn ensure_ccs_has_bus_padding_constraints( )); } - // This validation is intentionally strict and recognizes the canonical R1CS embedding: - // A(z) * B(z) - C(z) = 0, with padding rows using: - // A(z) = (1 - flag), B(z) = field, C(z) = 0. 
+ // This validation is intentionally strict and recognizes two safe patterns in the canonical + // R1CS embedding: + // + // 1) Explicit padding constraints: + // (1 - flag) * field = 0 + // + // 2) For *bit* fields only (addr bits), padding implied by a gated-bit constraint: + // bit * (bit - flag) = 0 + // combined with a boolean constraint on `flag`. When `flag ∈ {0,1}`, this implies: + // - flag=0 => bit=0 + // - flag=1 => bit ∈ {0,1} + // + // We accept (2) to allow circuits to drop redundant per-bit padding rows while staying safe. let (a_idx, b_idx, c_idx) = if s.matrices.len() >= 4 && s.matrices[0].is_identity() { (1usize, 2usize, 3usize) } else if s.matrices.len() >= 3 { @@ -806,11 +1285,14 @@ fn ensure_ccs_has_bus_padding_constraints( }; let n = s.n; - let empty = usize::MAX; - let multi = usize::MAX - 1; + let const_one_cols: HashSet = const_one_cols.iter().copied().collect(); let mut c_has_nonzero = vec![false; n]; - let mut b_col = vec![empty; n]; + let mut b_count = vec![0u8; n]; + let mut b_col1 = vec![0usize; n]; + let mut b_val1 = vec![F::ZERO; n]; + let mut b_col2 = vec![0usize; n]; + let mut b_val2 = vec![F::ZERO; n]; let mut a_count = vec![0u8; n]; let mut a_col1 = vec![0usize; n]; @@ -834,26 +1316,45 @@ fn ensure_ccs_has_bus_padding_constraints( } }; - let scan_b = |mat: &CcsMatrix, b_col: &mut [usize]| match mat { - CcsMatrix::Identity { n } => { - let cap = core::cmp::min(*n, b_col.len()); - for row in 0..cap { - b_col[row] = row; + let scan_b = |mat: &CcsMatrix, + b_count: &mut [u8], + b_col1: &mut [usize], + b_val1: &mut [F], + b_col2: &mut [usize], + b_val2: &mut [F]| { + match mat { + CcsMatrix::Identity { n } => { + let cap = core::cmp::min(*n, b_count.len()); + for row in 0..cap { + b_count[row] = 1; + b_col1[row] = row; + b_val1[row] = F::ONE; + } } - } - CcsMatrix::Csc(csc) => { - for col in 0..csc.ncols { - let s0 = csc.col_ptr[col]; - let e0 = csc.col_ptr[col + 1]; - for k in s0..e0 { - let row = csc.row_idx[k]; - if row >= 
b_col.len() { - continue; - } - if b_col[row] == empty { - b_col[row] = col; - } else { - b_col[row] = multi; + CcsMatrix::Csc(csc) => { + for col in 0..csc.ncols { + let s0 = csc.col_ptr[col]; + let e0 = csc.col_ptr[col + 1]; + for k in s0..e0 { + let row = csc.row_idx[k]; + if row >= b_count.len() { + continue; + } + match b_count[row] { + 0 => { + b_count[row] = 1; + b_col1[row] = col; + b_val1[row] = csc.vals[k]; + } + 1 => { + b_count[row] = 2; + b_col2[row] = col; + b_val2[row] = csc.vals[k]; + } + _ => { + b_count[row] = 3; + } + } } } } @@ -906,7 +1407,14 @@ fn ensure_ccs_has_bus_padding_constraints( }; scan_c(&s.matrices[c_idx], &mut c_has_nonzero); - scan_b(&s.matrices[b_idx], &mut b_col); + scan_b( + &s.matrices[b_idx], + &mut b_count, + &mut b_col1, + &mut b_val1, + &mut b_col2, + &mut b_val2, + ); scan_a( &s.matrices[a_idx], &mut a_count, @@ -916,16 +1424,47 @@ fn ensure_ccs_has_bus_padding_constraints( &mut a_val2, ); + // Find boolean constraints for flag columns: flag * (flag - 1) = 0. + let mut flag_is_boolean: HashSet = HashSet::new(); + let mut flag_boolean_row: HashMap = HashMap::new(); + let mut selector_rows: HashSet = HashSet::new(); + for row in 0..n { + if c_has_nonzero[row] { + continue; + } + if a_count[row] != 1 || a_val1[row] != F::ONE { + continue; + } + let flag_col = a_col1[row]; + if b_count[row] != 2 { + continue; + } + + let (c1, v1) = (b_col1[row], b_val1[row]); + let (c2, v2) = (b_col2[row], b_val2[row]); + + let ok = (c1 == flag_col && v1 == F::ONE && const_one_cols.contains(&c2) && v2 == -F::ONE) + || (c2 == flag_col && v2 == F::ONE && const_one_cols.contains(&c1) && v1 == -F::ONE); + if !ok { + continue; + } + + flag_is_boolean.insert(flag_col); + flag_boolean_row.entry(flag_col).or_insert(row); + selector_rows.insert(row); + } + + // Explicit padding constraints present in CCS: (1 - flag) * field = 0. 
let mut present: HashSet<(usize, usize)> = HashSet::new(); let mut padding_rows: HashSet = HashSet::new(); for row in 0..n { if c_has_nonzero[row] { continue; } - let field_col = b_col[row]; - if field_col == empty || field_col == multi { + if b_count[row] != 1 || b_val1[row] != F::ONE { continue; } + let field_col = b_col1[row]; if a_count[row] != 2 { continue; } @@ -950,14 +1489,67 @@ fn ensure_ccs_has_bus_padding_constraints( padding_rows.insert(row); } + // Bitness constraints that imply address-bit padding: bit * (bit - flag) = 0. + // + // We only treat these as satisfying padding when `flag` is also boolean. + let mut implied: HashMap<(usize, usize), usize> = HashMap::new(); + for row in 0..n { + if c_has_nonzero[row] { + continue; + } + if a_count[row] != 1 || a_val1[row] != F::ONE { + continue; + } + let bit_col = a_col1[row]; + if b_count[row] != 2 { + continue; + } + + let (c1, v1) = (b_col1[row], b_val1[row]); + let (c2, v2) = (b_col2[row], b_val2[row]); + + let flag_col = if c1 == bit_col && v1 == F::ONE && v2 == -F::ONE { + Some(c2) + } else if c2 == bit_col && v2 == F::ONE && v1 == -F::ONE { + Some(c1) + } else { + None + }; + let Some(flag_col) = flag_col else { + continue; + }; + if !flag_is_boolean.contains(&flag_col) { + continue; + } + + implied.insert((flag_col, bit_col), row); + } + let mut missing: Vec<&BusPaddingLabel> = Vec::new(); for req in required { + if present.contains(&(req.flag_z_idx, req.field_z_idx)) { + continue; + } + if implied.contains_key(&(req.flag_z_idx, req.field_z_idx)) { + continue; + } if !present.contains(&(req.flag_z_idx, req.field_z_idx)) { missing.push(req); } } if missing.is_empty() { + // Treat selector boolean and bitness constraints as padding-only rows so they don't + // accidentally satisfy "binding" presence checks. 
+ for row in selector_rows { + padding_rows.insert(row); + } + for &row in implied.values() { + padding_rows.insert(row); + } + for &row in flag_boolean_row.values() { + padding_rows.insert(row); + } return Ok(padding_rows); } @@ -980,10 +1572,11 @@ fn ensure_ccs_has_bus_padding_constraints( ))) } -fn ensure_ccs_binds_shared_bus_for_steps( +fn ensure_ccs_binds_shared_bus_for_steps>( s: &CcsStructure, bus: &BusLayout, padding_rows: &HashSet, + steps: &[S], ) -> Result<(), PiCcsError> { if bus.bus_cols == 0 { return Ok(()); @@ -991,7 +1584,7 @@ fn ensure_ccs_binds_shared_bus_for_steps( let required = required_bus_cols_for_layout(bus); ensure_ccs_references_bus_cols(s, bus, &required)?; - let binding_required = required_bus_binding_cols_for_layout(bus); + let binding_required = required_bus_binding_cols_for_layout(bus, steps); ensure_ccs_references_bus_cols_outside_padding_rows(s, bus, padding_rows, &binding_required) } @@ -1056,3 +1649,42 @@ pub(crate) fn build_time_sparse_from_bus_col( } Ok(SparseIdxVec::from_entries(pow2_cycle, entries)) } + +pub(crate) fn build_time_sparse_from_bus_col_at_js( + z: &[K], + bus: &BusLayout, + col_id: usize, + js: &[usize], + pow2_cycle: usize, +) -> Result, PiCcsError> { + if col_id >= bus.bus_cols { + return Err(PiCcsError::InvalidInput(format!( + "bus col_id out of range: {col_id} >= {}", + bus.bus_cols + ))); + } + let mut entries: Vec<(usize, K)> = Vec::new(); + for &j in js { + if j >= bus.chunk_size { + return Err(PiCcsError::InvalidInput(format!( + "bus j out of range: j={j} >= bus.chunk_size={}", + bus.chunk_size + ))); + } + let t = bus.time_index(j); + if t >= pow2_cycle { + return Err(PiCcsError::InvalidInput(format!( + "bus time index out of range: t={t} >= pow2_cycle={pow2_cycle}" + ))); + } + let idx = bus.bus_cell(col_id, j); + let v = z + .get(idx) + .copied() + .ok_or_else(|| PiCcsError::InvalidInput(format!("CPU witness too short for bus idx={idx}")))?; + if v != K::ZERO { + entries.push((t, v)); + } + } + 
Ok(SparseIdxVec::from_entries(pow2_cycle, entries)) +} diff --git a/crates/neo-fold/src/memory_sidecar/cpu_bus_tests.rs b/crates/neo-fold/src/memory_sidecar/cpu_bus_tests.rs index dced75a8..adfd5aff 100644 --- a/crates/neo-fold/src/memory_sidecar/cpu_bus_tests.rs +++ b/crates/neo-fold/src/memory_sidecar/cpu_bus_tests.rs @@ -1,10 +1,18 @@ #![allow(non_snake_case)] use super::cpu_bus::append_bus_openings_to_me_instance; +use super::cpu_bus::prepare_ccs_for_shared_cpu_bus_steps; use neo_ajtai::Commitment; +use neo_ccs::poly::{SparsePoly, Term}; +use neo_ccs::relations::CcsStructure; use neo_ccs::Mat; use neo_math::{D, F, K}; use neo_memory::cpu::build_bus_layout_for_instances; +use neo_memory::cpu::constraints::{ + CpuConstraint, CpuConstraintBuilder, CpuConstraintLabel, ShoutCpuBinding, TwistCpuBinding, +}; +use neo_memory::mem_init::MemInit; +use neo_memory::witness::{LutInstance, MemInstance, StepInstanceBundle}; use neo_params::NeoParams; use p3_field::PrimeCharacteristicRing; @@ -104,3 +112,317 @@ fn append_bus_openings_matches_manual_for_chunk_size_2() { assert_eq!(me.y_scalars[j_idx], expect_scalar, "col_id={col_id}"); } } + +fn build_identity_first_r1cs_ccs_from_cpu_constraints( + n: usize, + m: usize, + const_one_col: usize, + constraints: &[CpuConstraint], +) -> CcsStructure { + assert!(constraints.len() <= n, "too many constraints for n"); + + let mut a_data = vec![F::ZERO; n * m]; + let mut b_data = vec![F::ZERO; n * m]; + let c_data = vec![F::ZERO; n * m]; + + for (row, constraint) in constraints.iter().enumerate() { + if constraint.negate_condition { + a_data[row * m + const_one_col] = F::ONE; + a_data[row * m + constraint.condition_col] = -F::ONE; + for &col in &constraint.additional_condition_cols { + a_data[row * m + col] = -F::ONE; + } + } else { + a_data[row * m + constraint.condition_col] = F::ONE; + for &col in &constraint.additional_condition_cols { + a_data[row * m + col] = F::ONE; + } + } + + for &(col, coeff) in &constraint.b_terms { + 
b_data[row * m + col] += coeff; + } + } + + let i_n = Mat::identity(n); + let a = Mat::from_row_major(n, m, a_data); + let b = Mat::from_row_major(n, m, b_data); + let c = Mat::from_row_major(n, m, c_data); + + // Identity-first R1CS embedding: f(x0, x1, x2, x3) = x1*x2 - x3. + let f = SparsePoly::new( + 4, + vec![ + Term { + coeff: F::ONE, + exps: vec![0, 1, 1, 0], + }, + Term { + coeff: -F::ONE, + exps: vec![0, 0, 0, 1], + }, + ], + ); + + CcsStructure::new(vec![i_n, a, b, c], f).expect("build CCS from constraints") +} + +fn minimal_bus_steps( + m_in: usize, + chunk_size: usize, + shout_d: usize, + shout_ell: usize, + twist_d: usize, + twist_ell: usize, +) -> Vec> { + let mut x = vec![F::ONE; m_in]; + if m_in == 0 { + x = Vec::new(); + } + let mcs_inst = neo_ccs::McsInstance:: { + c: Commitment::zeros(1, 1), + x, + m_in, + }; + + let lut = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 1usize << shout_d, + d: shout_d, + n_side: 2, + steps: chunk_size, + lanes: 1, + ell: shout_ell, + table_spec: None, + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mem = MemInstance:: { + mem_id: 0, + comms: Vec::new(), + k: 1usize << twist_d, + d: twist_d, + n_side: 2, + steps: chunk_size, + lanes: 1, + ell: twist_ell, + init: MemInit::Zero, + }; + + let mut step: StepInstanceBundle = StepInstanceBundle::from(mcs_inst); + step.lut_insts = vec![lut]; + step.mem_insts = vec![mem]; + vec![step] +} + +#[test] +fn shared_cpu_bus_padding_validator_accepts_implied_addr_bit_padding() { + let m_in = 1usize; + let chunk_size = 1usize; + let shout_d = 2usize; + let shout_ell = 1usize; + let twist_d = 2usize; + let twist_ell = 1usize; + + // Minimal CPU+bus witness shape: + // - m_in=1 (const-one public input) + // - 8 CPU columns for bindings + // - bus tail for 1 shout (ell_addr=2) + 1 twist (ell_addr=2) + let cpu_cols = 8usize; + let bus_cols = (shout_d * shout_ell + 2) + (2 * (twist_d * twist_ell) + 5); + let m = m_in + cpu_cols + bus_cols; + let n 
= m; + let const_one_col = 0usize; + + let bus = build_bus_layout_for_instances(m, m_in, chunk_size, [shout_d * shout_ell], [twist_d * twist_ell]) + .expect("bus layout"); + assert!(bus.bus_cols > 0); + + // Bindings live in the CPU region immediately before the bus tail. + let cpu_has_read = m_in; + let cpu_has_write = m_in + 1; + let cpu_read_addr = m_in + 2; + let cpu_write_addr = m_in + 3; + let cpu_rv = m_in + 4; + let cpu_wv = m_in + 5; + let cpu_has_lookup = m_in + 6; + let cpu_lookup_val = m_in + 7; + + let shout = &bus.shout_cols[0].lanes[0]; + let twist = &bus.twist_cols[0].lanes[0]; + + let mut builder = CpuConstraintBuilder::::new(n, m, const_one_col); + builder.add_shout_instance_bound( + &bus, + shout, + &ShoutCpuBinding { + has_lookup: cpu_has_lookup, + addr: None, + val: cpu_lookup_val, + }, + ); + builder.add_twist_instance_bound( + &bus, + twist, + &TwistCpuBinding { + has_read: cpu_has_read, + has_write: cpu_has_write, + read_addr: cpu_read_addr, + write_addr: cpu_write_addr, + rv: cpu_rv, + wv: cpu_wv, + inc: None, + }, + ); + + // Use the builder's constraints but rebuild CCS locally so we can easily mutate the list in negative tests. 
+ let constraints = builder.constraints().to_vec(); + let ccs = build_identity_first_r1cs_ccs_from_cpu_constraints(n, m, const_one_col, &constraints); + + let steps = minimal_bus_steps(m_in, chunk_size, shout_d, shout_ell, twist_d, twist_ell); + prepare_ccs_for_shared_cpu_bus_steps(&ccs, &steps).expect("padding validator should accept"); +} + +#[test] +fn shared_cpu_bus_padding_validator_requires_flag_boolean_for_implied_padding() { + let m_in = 1usize; + let chunk_size = 1usize; + let shout_d = 2usize; + let shout_ell = 1usize; + let twist_d = 2usize; + let twist_ell = 1usize; + + let cpu_cols = 8usize; + let bus_cols = (shout_d * shout_ell + 2) + (2 * (twist_d * twist_ell) + 5); + let m = m_in + cpu_cols + bus_cols; + let n = m; + let const_one_col = 0usize; + + let bus = build_bus_layout_for_instances(m, m_in, chunk_size, [shout_d * shout_ell], [twist_d * twist_ell]) + .expect("bus layout"); + + let cpu_has_read = m_in; + let cpu_has_write = m_in + 1; + let cpu_read_addr = m_in + 2; + let cpu_write_addr = m_in + 3; + let cpu_rv = m_in + 4; + let cpu_wv = m_in + 5; + let cpu_has_lookup = m_in + 6; + let cpu_lookup_val = m_in + 7; + + let shout = &bus.shout_cols[0].lanes[0]; + let twist = &bus.twist_cols[0].lanes[0]; + + let mut builder = CpuConstraintBuilder::::new(n, m, const_one_col); + builder.add_shout_instance_bound( + &bus, + shout, + &ShoutCpuBinding { + has_lookup: cpu_has_lookup, + addr: None, + val: cpu_lookup_val, + }, + ); + builder.add_twist_instance_bound( + &bus, + twist, + &TwistCpuBinding { + has_read: cpu_has_read, + has_write: cpu_has_write, + read_addr: cpu_read_addr, + write_addr: cpu_write_addr, + rv: cpu_rv, + wv: cpu_wv, + inc: None, + }, + ); + + let constraints: Vec> = builder + .constraints() + .iter() + .cloned() + .filter(|c| c.label != CpuConstraintLabel::ShoutHasLookupBoolean) + .collect(); + // Sanity: we actually removed something. 
+ assert!(constraints.len() < builder.constraints().len()); + let ccs = build_identity_first_r1cs_ccs_from_cpu_constraints(n, m, const_one_col, &constraints); + + let steps = minimal_bus_steps(m_in, chunk_size, shout_d, shout_ell, twist_d, twist_ell); + assert!( + prepare_ccs_for_shared_cpu_bus_steps(&ccs, &steps).is_err(), + "validator unexpectedly accepted implied padding without a boolean constraint on has_lookup" + ); +} + +#[test] +fn shared_cpu_bus_padding_validator_requires_explicit_padding_for_nonbit_fields() { + let m_in = 1usize; + let chunk_size = 1usize; + let shout_d = 2usize; + let shout_ell = 1usize; + let twist_d = 2usize; + let twist_ell = 1usize; + + let cpu_cols = 8usize; + let bus_cols = (shout_d * shout_ell + 2) + (2 * (twist_d * twist_ell) + 5); + let m = m_in + cpu_cols + bus_cols; + let n = m; + let const_one_col = 0usize; + + let bus = build_bus_layout_for_instances(m, m_in, chunk_size, [shout_d * shout_ell], [twist_d * twist_ell]) + .expect("bus layout"); + + let cpu_has_read = m_in; + let cpu_has_write = m_in + 1; + let cpu_read_addr = m_in + 2; + let cpu_write_addr = m_in + 3; + let cpu_rv = m_in + 4; + let cpu_wv = m_in + 5; + let cpu_has_lookup = m_in + 6; + let cpu_lookup_val = m_in + 7; + + let shout = &bus.shout_cols[0].lanes[0]; + let twist = &bus.twist_cols[0].lanes[0]; + + let mut builder = CpuConstraintBuilder::::new(n, m, const_one_col); + builder.add_shout_instance_bound( + &bus, + shout, + &ShoutCpuBinding { + has_lookup: cpu_has_lookup, + addr: None, + val: cpu_lookup_val, + }, + ); + builder.add_twist_instance_bound( + &bus, + twist, + &TwistCpuBinding { + has_read: cpu_has_read, + has_write: cpu_has_write, + read_addr: cpu_read_addr, + write_addr: cpu_write_addr, + rv: cpu_rv, + wv: cpu_wv, + inc: None, + }, + ); + + let constraints: Vec> = builder + .constraints() + .iter() + .cloned() + .filter(|c| c.label != CpuConstraintLabel::ReadValueZeroPadding) + .collect(); + assert!(constraints.len() < 
builder.constraints().len()); + let ccs = build_identity_first_r1cs_ccs_from_cpu_constraints(n, m, const_one_col, &constraints); + + let steps = minimal_bus_steps(m_in, chunk_size, shout_d, shout_ell, twist_d, twist_ell); + assert!( + prepare_ccs_for_shared_cpu_bus_steps(&ccs, &steps).is_err(), + "validator unexpectedly accepted missing explicit padding for rv" + ); +} diff --git a/crates/neo-fold/src/memory_sidecar/memory.rs b/crates/neo-fold/src/memory_sidecar/memory.rs index 5cb93d70..8771316d 100644 --- a/crates/neo-fold/src/memory_sidecar/memory.rs +++ b/crates/neo-fold/src/memory_sidecar/memory.rs @@ -8,2920 +8,78 @@ use crate::shard_proof_types::{ use crate::PiCcsError; use neo_ajtai::Commitment as Cmt; use neo_ccs::{CcsStructure, MeInstance}; -use neo_math::{F, K}; +use neo_math::{KExtensions, F, K}; use neo_memory::bit_ops::{eq_bit_affine, eq_bits_prod}; -use neo_memory::cpu::BusLayout; +use neo_memory::cpu::{ + build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes, BusLayout, ShoutInstanceShape, +}; use neo_memory::identity::shout_oracle::IdentityAddressLookupOracleSparse; use neo_memory::mle::{eq_points, lt_eval}; use neo_memory::riscv::shout_oracle::RiscvAddressLookupOracleSparse; +use neo_memory::riscv::trace::{ + rv32_decode_lookup_backed_cols, rv32_decode_lookup_backed_row_from_instr_word, rv32_decode_lookup_table_id_for_col, + rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, rv32_width_lookup_backed_cols, + rv32_width_lookup_table_id_for_col, Rv32DecodeSidecarLayout, Rv32TraceLayout, Rv32WidthSidecarLayout, +}; use neo_memory::sparse_time::SparseIdxVec; use neo_memory::ts_common as ts; use neo_memory::twist_oracle::{ - AddressLookupOracle, IndexAdapterOracleSparseTime, LazyWeightedBitnessOracleSparseTime, ShoutValueOracleSparse, - TwistLaneSparseCols, TwistReadCheckAddrOracleSparseTimeMultiLane, TwistReadCheckOracleSparseTime, - TwistTotalIncOracleSparseTime, TwistValEvalOracleSparseTime, 
TwistWriteCheckAddrOracleSparseTimeMultiLane, - TwistWriteCheckOracleSparseTime, + AddressLookupOracle, IndexAdapterOracleSparseTime, LazyWeightedBitnessOracleSparseTime, + Rv32PackedAddOracleSparseTime, Rv32PackedAndOracleSparseTime, Rv32PackedAndnOracleSparseTime, + Rv32PackedBitwiseAdapterOracleSparseTime, Rv32PackedDivOracleSparseTime, Rv32PackedDivRemAdapterOracleSparseTime, + Rv32PackedDivRemuAdapterOracleSparseTime, Rv32PackedDivuOracleSparseTime, Rv32PackedEqAdapterOracleSparseTime, + Rv32PackedEqOracleSparseTime, Rv32PackedMulHiOracleSparseTime, Rv32PackedMulOracleSparseTime, + Rv32PackedMulhAdapterOracleSparseTime, Rv32PackedMulhsuAdapterOracleSparseTime, Rv32PackedMulhuOracleSparseTime, + Rv32PackedNeqAdapterOracleSparseTime, Rv32PackedNeqOracleSparseTime, Rv32PackedOrOracleSparseTime, + Rv32PackedRemOracleSparseTime, Rv32PackedRemuOracleSparseTime, Rv32PackedSllOracleSparseTime, + Rv32PackedSltOracleSparseTime, Rv32PackedSltuOracleSparseTime, Rv32PackedSraAdapterOracleSparseTime, + Rv32PackedSraOracleSparseTime, Rv32PackedSrlAdapterOracleSparseTime, Rv32PackedSrlOracleSparseTime, + Rv32PackedSubOracleSparseTime, Rv32PackedXorOracleSparseTime, ShoutValueOracleSparse, TwistLaneSparseCols, + TwistReadCheckAddrOracleSparseTimeMultiLane, TwistReadCheckOracleSparseTime, TwistTotalIncOracleSparseTime, + TwistValEvalOracleSparseTime, TwistWriteCheckAddrOracleSparseTimeMultiLane, TwistWriteCheckOracleSparseTime, + U32DecompOracleSparseTime, ZeroOracleSparseTime, }; use neo_memory::witness::{LutInstance, LutTableSpec, MemInstance, StepInstanceBundle, StepWitnessBundle}; use neo_memory::{eval_init_at_r_addr, twist, BatchedAddrProof, MemInit}; use neo_params::NeoParams; use neo_reductions::sumcheck::{BatchedClaim, RoundOracle}; use neo_transcript::{Poseidon2Transcript, Transcript}; +use p3_field::Field; use p3_field::PrimeCharacteristicRing; - -// ============================================================================ -// Transcript binding -// 
============================================================================ - -fn bind_shout_table_spec(tr: &mut Poseidon2Transcript, spec: &Option) { - let Some(spec) = spec else { - return; - }; - - tr.append_message(b"shout/table_spec/tag", &[1u8]); - match spec { - LutTableSpec::RiscvOpcode { opcode, xlen } => { - // Stable numeric encoding: align with `RiscvShoutTables::opcode_to_id`. - let opcode_id: u64 = match opcode { - neo_memory::riscv::lookups::RiscvOpcode::And => 0, - neo_memory::riscv::lookups::RiscvOpcode::Xor => 1, - neo_memory::riscv::lookups::RiscvOpcode::Or => 2, - neo_memory::riscv::lookups::RiscvOpcode::Add => 3, - neo_memory::riscv::lookups::RiscvOpcode::Sub => 4, - neo_memory::riscv::lookups::RiscvOpcode::Slt => 5, - neo_memory::riscv::lookups::RiscvOpcode::Sltu => 6, - neo_memory::riscv::lookups::RiscvOpcode::Sll => 7, - neo_memory::riscv::lookups::RiscvOpcode::Srl => 8, - neo_memory::riscv::lookups::RiscvOpcode::Sra => 9, - neo_memory::riscv::lookups::RiscvOpcode::Eq => 10, - neo_memory::riscv::lookups::RiscvOpcode::Neq => 11, - neo_memory::riscv::lookups::RiscvOpcode::Mul => 12, - neo_memory::riscv::lookups::RiscvOpcode::Mulh => 13, - neo_memory::riscv::lookups::RiscvOpcode::Mulhu => 14, - neo_memory::riscv::lookups::RiscvOpcode::Mulhsu => 15, - neo_memory::riscv::lookups::RiscvOpcode::Div => 16, - neo_memory::riscv::lookups::RiscvOpcode::Divu => 17, - neo_memory::riscv::lookups::RiscvOpcode::Rem => 18, - neo_memory::riscv::lookups::RiscvOpcode::Remu => 19, - neo_memory::riscv::lookups::RiscvOpcode::Addw => 20, - neo_memory::riscv::lookups::RiscvOpcode::Subw => 21, - neo_memory::riscv::lookups::RiscvOpcode::Sllw => 22, - neo_memory::riscv::lookups::RiscvOpcode::Srlw => 23, - neo_memory::riscv::lookups::RiscvOpcode::Sraw => 24, - neo_memory::riscv::lookups::RiscvOpcode::Mulw => 25, - neo_memory::riscv::lookups::RiscvOpcode::Divw => 26, - neo_memory::riscv::lookups::RiscvOpcode::Divuw => 27, - neo_memory::riscv::lookups::RiscvOpcode::Remw 
=> 28, - neo_memory::riscv::lookups::RiscvOpcode::Remuw => 29, - neo_memory::riscv::lookups::RiscvOpcode::Andn => 30, - }; - - tr.append_message(b"shout/table_spec/riscv/tag", &[1u8]); - tr.append_message(b"shout/table_spec/riscv/opcode_id", &opcode_id.to_le_bytes()); - tr.append_message(b"shout/table_spec/riscv/xlen", &(*xlen as u64).to_le_bytes()); - } - LutTableSpec::IdentityU32 => { - tr.append_message(b"shout/table_spec/identity_u32/tag", &[1u8]); - } - } -} - -fn absorb_step_memory_impl<'a, LI, MI>(tr: &mut Poseidon2Transcript, mut lut_insts: LI, mut mem_insts: MI) -where - LI: ExactSizeIterator>, - MI: ExactSizeIterator>, -{ - tr.append_message(b"step/absorb_memory_start", &[]); - tr.append_message(b"step/lut_count", &(lut_insts.len() as u64).to_le_bytes()); - for (i, inst) in lut_insts.by_ref().enumerate() { - // Bind public LUT parameters before any challenges. - tr.append_message(b"step/lut_idx", &(i as u64).to_le_bytes()); - tr.append_message(b"shout/k", &(inst.k as u64).to_le_bytes()); - tr.append_message(b"shout/d", &(inst.d as u64).to_le_bytes()); - tr.append_message(b"shout/n_side", &(inst.n_side as u64).to_le_bytes()); - tr.append_message(b"shout/steps", &(inst.steps as u64).to_le_bytes()); - tr.append_message(b"shout/ell", &(inst.ell as u64).to_le_bytes()); - tr.append_message(b"shout/lanes", &(inst.lanes.max(1) as u64).to_le_bytes()); - bind_shout_table_spec(tr, &inst.table_spec); - let table_digest = digest_fields(b"shout/table", &inst.table); - tr.append_message(b"shout/table_digest", &table_digest); - } - tr.append_message(b"step/mem_count", &(mem_insts.len() as u64).to_le_bytes()); - for (i, inst) in mem_insts.by_ref().enumerate() { - // Bind public memory parameters before any challenges. 
- tr.append_message(b"step/mem_idx", &(i as u64).to_le_bytes()); - tr.append_message(b"twist/k", &(inst.k as u64).to_le_bytes()); - tr.append_message(b"twist/d", &(inst.d as u64).to_le_bytes()); - tr.append_message(b"twist/n_side", &(inst.n_side as u64).to_le_bytes()); - tr.append_message(b"twist/steps", &(inst.steps as u64).to_le_bytes()); - tr.append_message(b"twist/ell", &(inst.ell as u64).to_le_bytes()); - tr.append_message(b"twist/lanes", &(inst.lanes.max(1) as u64).to_le_bytes()); - let init_digest = match &inst.init { - MemInit::Zero => digest_fields(b"twist/init/zero", &[]), - MemInit::Sparse(pairs) => { - let mut fs = Vec::with_capacity(2 * pairs.len()); - for (addr, val) in pairs.iter() { - fs.push(F::from_u64(*addr)); - fs.push(*val); - } - digest_fields(b"twist/init/sparse", &fs) - } - }; - tr.append_message(b"twist/init_digest", &init_digest); - } - tr.append_message(b"step/absorb_memory_done", &[]); -} - -pub fn absorb_step_memory(tr: &mut Poseidon2Transcript, step: &StepInstanceBundle) { - absorb_step_memory_impl(tr, step.lut_insts.iter(), step.mem_insts.iter()); -} - -pub(crate) fn absorb_step_memory_witness(tr: &mut Poseidon2Transcript, step: &StepWitnessBundle) { - absorb_step_memory_impl( - tr, - step.lut_instances.iter().map(|(inst, _)| inst), - step.mem_instances.iter().map(|(inst, _)| inst), - ); -} - -// ============================================================================ -// Prover helpers -// ============================================================================ - -pub(crate) struct ShoutDecodedColsSparse { - pub lanes: Vec, -} - -pub(crate) struct ShoutLaneSparseCols { - pub addr_bits: Vec>, - pub has_lookup: SparseIdxVec, - pub val: SparseIdxVec, -} - -pub(crate) struct TwistDecodedColsSparse { - pub lanes: Vec, -} - -pub(crate) struct SumRoundOracle { - oracles: Vec>, - num_rounds: usize, - degree_bound: usize, -} - -impl SumRoundOracle { - pub(crate) fn new(oracles: Vec>) -> Self { - if oracles.is_empty() { - 
panic!("SumRoundOracle requires at least one oracle"); - } - - let num_rounds = oracles[0].num_rounds(); - let degree_bound = oracles[0].degree_bound(); - for (idx, o) in oracles.iter().enumerate().skip(1) { - if o.num_rounds() != num_rounds { - panic!( - "SumRoundOracle num_rounds mismatch at idx={idx} (got {}, expected {num_rounds})", - o.num_rounds() - ); - } - if o.degree_bound() != degree_bound { - panic!( - "SumRoundOracle degree_bound mismatch at idx={idx} (got {}, expected {degree_bound})", - o.degree_bound() - ); - } - } - - Self { - oracles, - num_rounds, - degree_bound, - } - } -} - -impl RoundOracle for SumRoundOracle { - fn evals_at(&mut self, points: &[K]) -> Vec { - let mut acc = vec![K::ZERO; points.len()]; - for o in self.oracles.iter_mut() { - let ys = o.evals_at(points); - if ys.len() != acc.len() { - panic!( - "SumRoundOracle eval length mismatch (got {}, expected {})", - ys.len(), - acc.len() - ); - } - for (a, y) in acc.iter_mut().zip(ys) { - *a += y; - } - } - acc - } - - fn num_rounds(&self) -> usize { - self.num_rounds - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - for o in self.oracles.iter_mut() { - o.fold(r); - } - self.num_rounds = self.oracles[0].num_rounds(); - } -} - -fn build_twist_inc_terms_at_r_addr(lanes: &[TwistLaneSparseCols], r_addr: &[K]) -> Vec<(usize, K)> { - let ell_addr = r_addr.len(); - let mut out: Vec<(usize, K)> = Vec::new(); - - for lane in lanes.iter() { - debug_assert_eq!(lane.wa_bits.len(), ell_addr, "wa_bits len mismatch"); - for &(t, has_w) in lane.has_write.entries() { - let inc_t = lane.inc_at_write_addr.get(t); - if has_w == K::ZERO || inc_t == K::ZERO { - continue; - } - - let mut eq_addr = K::ONE; - for (b, col) in lane.wa_bits.iter().enumerate() { - let bit = col.get(t); - eq_addr *= eq_bit_affine(bit, r_addr[b]); - } - - let inc_at_r_addr = has_w * inc_t * eq_addr; - if inc_at_r_addr != K::ZERO { - out.push((t, inc_at_r_addr)); - } - } - } - - out -} - 
-pub struct RouteAShoutTimeOracles { - pub lanes: Vec, - pub bitness: Vec>, - pub ell_addr: usize, -} - -pub struct RouteAShoutTimeLaneOracles { - pub value: Box, - pub value_claim: K, - pub adapter: Box, - pub adapter_claim: K, -} - -pub struct RouteATwistTimeOracles { - pub read_check: Box, - pub write_check: Box, - pub bitness: Vec>, - pub ell_addr: usize, -} - -pub struct RouteAMemoryOracles { - pub shout: Vec, - pub twist: Vec, -} - -pub trait TimeBatchedClaims { - fn append_time_claims<'a>( - &'a mut self, - ell_n: usize, - claimed_sums: &mut Vec, - degree_bounds: &mut Vec, - labels: &mut Vec<&'static [u8]>, - claim_is_dynamic: &mut Vec, - claims: &mut Vec>, - ); -} - -pub(crate) struct ShoutAddrPreBatchProverData { - pub addr_pre: ShoutAddrPreProof, - pub decoded: Vec, -} - -pub struct ShoutAddrPreVerifyData { - pub is_active: bool, - pub addr_claim_sum: K, - pub addr_final: K, - pub r_addr: Vec, - pub table_eval_at_r_addr: K, -} - -pub(crate) struct TwistAddrPreProverData { - pub addr_pre: BatchedAddrProof, - pub decoded: TwistDecodedColsSparse, - /// Time-lane claimed sum for the read-check oracle (output of addr-pre). - pub read_check_claim_sum: K, - /// Time-lane claimed sum for the write-check oracle (output of addr-pre). 
- pub write_check_claim_sum: K, -} - -pub struct TwistAddrPreVerifyData { - pub r_addr: Vec, - pub read_check_claim_sum: K, - pub write_check_claim_sum: K, -} - -#[derive(Clone, Debug)] -pub struct TwistTimeLaneOpeningsLane { - pub wa_bits: Vec, - pub has_write: K, - pub inc_at_write_addr: K, -} - -#[derive(Clone, Debug)] -pub struct TwistTimeLaneOpenings { - pub lanes: Vec, -} - -#[derive(Clone, Debug)] -pub struct RouteAMemoryVerifyOutput { - pub claim_idx_end: usize, - pub twist_time_openings: Vec, -} - -pub(crate) fn prove_twist_addr_pre_time( - tr: &mut Poseidon2Transcript, - params: &NeoParams, - step: &StepWitnessBundle, - cpu_bus: Option<&BusLayout>, - ell_n: usize, - r_cycle: &[K], -) -> Result, PiCcsError> { - if step.mem_instances.is_empty() { - return Ok(Vec::new()); - } - let mut out = Vec::with_capacity(step.mem_instances.len()); - - let bus = - cpu_bus.ok_or_else(|| PiCcsError::InvalidInput("prove_twist_addr_pre_time requires shared_cpu_bus".into()))?; - let cpu_z_k = crate::memory_sidecar::cpu_bus::decode_cpu_z_to_k(params, &step.mcs.1.Z); - if bus.shout_cols.len() != step.lut_instances.len() || bus.twist_cols.len() != step.mem_instances.len() { - return Err(PiCcsError::InvalidInput( - "shared_cpu_bus layout mismatch for step (instance counts)".into(), - )); - } - - for (idx, (mem_inst, _mem_wit)) in step.mem_instances.iter().enumerate() { - neo_memory::addr::validate_twist_bit_addressing(mem_inst)?; - let pow2_cycle = 1usize << ell_n; - if mem_inst.steps > pow2_cycle { - return Err(PiCcsError::InvalidInput(format!( - "Twist(Route A): steps={} exceeds 2^ell_cycle={pow2_cycle}", - mem_inst.steps - ))); - } - - let z = &cpu_z_k; - - let ell_addr = mem_inst.d * mem_inst.ell; - let expected_lanes = mem_inst.lanes.max(1); - let twist_inst_cols = bus.twist_cols.get(idx).ok_or_else(|| { - PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch: missing twist_cols for mem_idx={idx}" - )) - })?; - if twist_inst_cols.lanes.len() != expected_lanes 
{ - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at mem_idx={idx}: expected lanes={expected_lanes}, got {}", - twist_inst_cols.lanes.len() - ))); - } - - let mut lanes: Vec = Vec::with_capacity(twist_inst_cols.lanes.len()); - for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { - if twist_cols.ra_bits.end - twist_cols.ra_bits.start != ell_addr - || twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr - { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at mem_idx={idx}, lane={lane_idx}: expected ell_addr={ell_addr}" - ))); - } - - let mut ra_bits = Vec::with_capacity(ell_addr); - for col_id in twist_cols.ra_bits.clone() { - ra_bits.push(crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - col_id, - mem_inst.steps, - pow2_cycle, - )?); - } - - let mut wa_bits = Vec::with_capacity(ell_addr); - for col_id in twist_cols.wa_bits.clone() { - wa_bits.push(crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - col_id, - mem_inst.steps, - pow2_cycle, - )?); - } - - let has_read = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - twist_cols.has_read, - mem_inst.steps, - pow2_cycle, - )?; - let has_write = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - twist_cols.has_write, - mem_inst.steps, - pow2_cycle, - )?; - let wv = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - twist_cols.wv, - mem_inst.steps, - pow2_cycle, - )?; - let rv = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - twist_cols.rv, - mem_inst.steps, - pow2_cycle, - )?; - let inc_at_write_addr = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - twist_cols.inc, - mem_inst.steps, - pow2_cycle, - )?; - - lanes.push(TwistLaneSparseCols { - ra_bits, - wa_bits, - has_read, - has_write, - wv, - rv, - inc_at_write_addr, - }); - } - - let 
decoded = TwistDecodedColsSparse { lanes }; - - let init_sparse: Vec<(usize, K)> = match &mem_inst.init { - MemInit::Zero => Vec::new(), - MemInit::Sparse(pairs) => pairs - .iter() - .map(|(addr, val)| { - let addr_usize = usize::try_from(*addr).map_err(|_| { - PiCcsError::InvalidInput(format!("Twist: init address doesn't fit usize: addr={addr}")) - })?; - if addr_usize >= mem_inst.k { - return Err(PiCcsError::InvalidInput(format!( - "Twist: init address out of range: addr={addr} >= k={}", - mem_inst.k - ))); - } - Ok((addr_usize, (*val).into())) - }) - .collect::>()?, - }; - - let mut read_addr_oracle = - TwistReadCheckAddrOracleSparseTimeMultiLane::new(init_sparse.clone(), r_cycle, &decoded.lanes); - let mut write_addr_oracle = - TwistWriteCheckAddrOracleSparseTimeMultiLane::new(init_sparse, r_cycle, &decoded.lanes); - - let labels: [&[u8]; 2] = [b"twist/read_addr_pre".as_slice(), b"twist/write_addr_pre".as_slice()]; - let claimed_sums = vec![K::ZERO, K::ZERO]; - tr.append_message(b"twist/addr_pre_time/claim_idx", &(idx as u64).to_le_bytes()); - bind_batched_claim_sums(tr, b"twist/addr_pre_time/claimed_sums", &claimed_sums, &labels); - - let mut claims = [ - BatchedClaim { - oracle: &mut read_addr_oracle, - claimed_sum: K::ZERO, - label: labels[0], - }, - BatchedClaim { - oracle: &mut write_addr_oracle, - claimed_sum: K::ZERO, - label: labels[1], - }, - ]; - - let (r_addr, per_claim_results) = run_batched_sumcheck_prover_ds(tr, b"twist/addr_pre_time", idx, &mut claims)?; - if per_claim_results.len() != 2 { - return Err(PiCcsError::ProtocolError(format!( - "twist addr-pre per-claim results len()={}, expected 2", - per_claim_results.len() - ))); - } - - out.push(TwistAddrPreProverData { - addr_pre: BatchedAddrProof { - claimed_sums, - round_polys: vec![ - per_claim_results[0].round_polys.clone(), - per_claim_results[1].round_polys.clone(), - ], - r_addr: r_addr.clone(), - }, - decoded, - read_check_claim_sum: per_claim_results[0].final_value, - 
write_check_claim_sum: per_claim_results[1].final_value, - }); - } - - Ok(out) -} - -pub(crate) fn prove_shout_addr_pre_time( - tr: &mut Poseidon2Transcript, - params: &NeoParams, - step: &StepWitnessBundle, - cpu_bus: Option<&BusLayout>, - ell_n: usize, - r_cycle: &[K], - step_idx: usize, -) -> Result { - if step.lut_instances.is_empty() { - return Ok(ShoutAddrPreBatchProverData { - addr_pre: ShoutAddrPreProof::default(), - decoded: Vec::new(), - }); - } - - let bus = - cpu_bus.ok_or_else(|| PiCcsError::InvalidInput("prove_shout_addr_pre_time requires shared_cpu_bus".into()))?; - let cpu_z_k = crate::memory_sidecar::cpu_bus::decode_cpu_z_to_k(params, &step.mcs.1.Z); - if bus.shout_cols.len() != step.lut_instances.len() || bus.twist_cols.len() != step.mem_instances.len() { - return Err(PiCcsError::InvalidInput( - "shared_cpu_bus layout mismatch for step (instance counts)".into(), - )); - } - - let pow2_cycle = 1usize << ell_n; - let n_lut = step.lut_instances.len(); - let total_lanes: usize = step - .lut_instances - .iter() - .map(|(inst, _)| inst.lanes.max(1)) - .sum(); - - let mut decoded_cols: Vec = Vec::with_capacity(n_lut); - let mut claimed_sums: Vec = vec![K::ZERO; total_lanes]; - - struct AddrPreGroupBuilder { - active_lanes: Vec, - active_claimed_sums: Vec, - addr_oracles: Vec>, - } - - // Group Shout addr-pre claims by `ell_addr` so we can run one batched sumcheck per group. 
- let mut groups: std::collections::BTreeMap = std::collections::BTreeMap::new(); - - let mut flat_lane_idx: usize = 0; - for (idx, (lut_inst, _lut_wit)) in step.lut_instances.iter().enumerate() { - neo_memory::addr::validate_shout_bit_addressing(lut_inst)?; - if lut_inst.steps > pow2_cycle { - return Err(PiCcsError::InvalidInput(format!( - "Shout(Route A): steps={} exceeds 2^ell_cycle={pow2_cycle}", - lut_inst.steps - ))); - } - - let z = &cpu_z_k; - let inst_ell_addr = lut_inst.d * lut_inst.ell; - let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) - .map_err(|_| PiCcsError::InvalidInput("Shout(Route A): ell_addr overflows u32".into()))?; - groups - .entry(inst_ell_addr_u32) - .or_insert_with(|| AddrPreGroupBuilder { - active_lanes: Vec::new(), - active_claimed_sums: Vec::new(), - addr_oracles: Vec::new(), - }); - let inst_cols = bus.shout_cols.get(idx).ok_or_else(|| { - PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch: missing shout_cols for lut_idx={idx}" - )) - })?; - let expected_lanes = lut_inst.lanes.max(1); - if inst_cols.lanes.len() != expected_lanes { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at lut_idx={idx}: shout lanes={} but instance expects {}", - inst_cols.lanes.len(), - expected_lanes - ))); - } - - let mut lanes: Vec = Vec::with_capacity(expected_lanes); - - for (lane_idx, shout_cols) in inst_cols.lanes.iter().enumerate() { - if shout_cols.addr_bits.end - shout_cols.addr_bits.start != inst_ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at lut_idx={idx}, lane_idx={lane_idx}: expected ell_addr={inst_ell_addr}" - ))); - } - - let mut addr_bits = Vec::with_capacity(inst_ell_addr); - for col_id in shout_cols.addr_bits.clone() { - addr_bits.push(crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - col_id, - lut_inst.steps, - pow2_cycle, - )?); - } - - let has_lookup = 
crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - shout_cols.has_lookup, - lut_inst.steps, - pow2_cycle, - )?; - let val = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( - z, - bus, - shout_cols.val, - lut_inst.steps, - pow2_cycle, - )?; - - let has_any_lookup = has_lookup - .entries() - .iter() - .any(|&(_t, gate)| gate != K::ZERO); - - if has_any_lookup { - let (addr_oracle, lane_sum): (Box, K) = match &lut_inst.table_spec { - None => { - let table_k: Vec = lut_inst.table.iter().map(|&v| v.into()).collect(); - let (o, sum) = - AddressLookupOracle::new(&addr_bits, &has_lookup, &table_k, r_cycle, inst_ell_addr); - (Box::new(o), sum) - } - Some(LutTableSpec::RiscvOpcode { opcode, xlen }) => { - let (o, sum) = RiscvAddressLookupOracleSparse::new_sparse_time( - *opcode, - *xlen, - &addr_bits, - &has_lookup, - r_cycle, - )?; - (Box::new(o), sum) - } - Some(LutTableSpec::IdentityU32) => { - let (o, sum) = IdentityAddressLookupOracleSparse::new_sparse_time( - inst_ell_addr, - &addr_bits, - &has_lookup, - r_cycle, - )?; - (Box::new(o), sum) - } - }; - - claimed_sums[flat_lane_idx] = lane_sum; - let lane_idx_u32 = u32::try_from(flat_lane_idx) - .map_err(|_| PiCcsError::InvalidInput("Shout(Route A): lane index overflow".into()))?; - let group = groups - .get_mut(&inst_ell_addr_u32) - .ok_or_else(|| PiCcsError::ProtocolError("Shout(Route A): missing ell_addr group".into()))?; - group.active_lanes.push(lane_idx_u32); - group.active_claimed_sums.push(lane_sum); - group.addr_oracles.push(addr_oracle); - } - - lanes.push(ShoutLaneSparseCols { - addr_bits, - has_lookup, - val, - }); - flat_lane_idx += 1; - } - - let decoded = ShoutDecodedColsSparse { lanes }; - - decoded_cols.push(decoded); - } - if flat_lane_idx != total_lanes { - return Err(PiCcsError::ProtocolError(format!( - "Shout(Route A): flat lane indexing drift (got {flat_lane_idx}, expected {total_lanes})" - ))); - } - - let labels_all: Vec<&'static [u8]> = 
vec![b"shout/addr_pre".as_slice(); total_lanes]; - tr.append_message(b"shout/addr_pre_time/step_idx", &(step_idx as u64).to_le_bytes()); - bind_batched_claim_sums(tr, b"shout/addr_pre_time/claimed_sums", &claimed_sums, &labels_all); - - let mut group_proofs: Vec> = Vec::with_capacity(groups.len()); - for (group_idx, (ell_addr, mut group)) in groups.into_iter().enumerate() { - tr.append_message(b"shout/addr_pre_time/group_idx", &(group_idx as u64).to_le_bytes()); - tr.append_message(b"shout/addr_pre_time/group_ell_addr", &(ell_addr as u64).to_le_bytes()); - - let (r_addr, round_polys) = if group.active_lanes.is_empty() { - // No active lanes in this `ell_addr` group; sample an arbitrary `r_addr` without running sumcheck. - tr.append_message(b"shout/addr_pre_time/no_sumcheck", &(step_idx as u64).to_le_bytes()); - tr.append_message( - b"shout/addr_pre_time/no_sumcheck/ell_addr", - &(ell_addr as u64).to_le_bytes(), - ); - ( - ts::sample_ext_point( - tr, - b"shout/addr_pre_time/no_sumcheck/r_addr", - b"shout/addr_pre_time/no_sumcheck/r_addr/0", - b"shout/addr_pre_time/no_sumcheck/r_addr/1", - ell_addr as usize, - ), - Vec::new(), - ) - } else { - let labels_active: Vec<&'static [u8]> = vec![b"shout/addr_pre".as_slice(); group.addr_oracles.len()]; - let mut claims: Vec> = group - .addr_oracles - .iter_mut() - .zip(group.active_claimed_sums.iter()) - .zip(labels_active.iter()) - .map(|((oracle, sum), label)| BatchedClaim { - oracle: oracle.as_mut(), - claimed_sum: *sum, - label: *label, - }) - .collect(); - - let (r_addr, per_claim_results) = - run_batched_sumcheck_prover_ds(tr, b"shout/addr_pre_time", step_idx, claims.as_mut_slice())?; - let round_polys = per_claim_results - .iter() - .map(|r| r.round_polys.clone()) - .collect::>(); - (r_addr, round_polys) - }; - - group_proofs.push(ShoutAddrPreGroupProof { - ell_addr, - active_lanes: group.active_lanes, - round_polys, - r_addr, - }); - } - - Ok(ShoutAddrPreBatchProverData { - addr_pre: ShoutAddrPreProof { - 
claimed_sums, - groups: group_proofs, - }, - decoded: decoded_cols, - }) -} - -pub fn verify_shout_addr_pre_time( - tr: &mut Poseidon2Transcript, - step: &StepInstanceBundle, - mem_proof: &MemSidecarProof, - step_idx: usize, -) -> Result, PiCcsError> { - let proof = &mem_proof.shout_addr_pre; - - if step.lut_insts.is_empty() { - if !proof.claimed_sums.is_empty() || !proof.groups.is_empty() { - return Err(PiCcsError::InvalidInput( - "shout_addr_pre must be empty when there are no Shout instances".into(), - )); - } - return Ok(Vec::new()); - } - - let total_lanes: usize = step.lut_insts.iter().map(|inst| inst.lanes.max(1)).sum(); - if proof.claimed_sums.len() != total_lanes { - return Err(PiCcsError::InvalidInput(format!( - "shout_addr_pre claimed_sums.len()={}, expected total_lanes={}", - proof.claimed_sums.len(), - total_lanes - ))); - } - - // Flatten lane->ell_addr mapping in canonical order so we can validate group membership and - // attach the correct `r_addr` per lane. - let mut lane_ell_addr: Vec = Vec::with_capacity(total_lanes); - let mut required_ell_addrs: std::collections::BTreeSet = std::collections::BTreeSet::new(); - for lut_inst in step.lut_insts.iter() { - neo_memory::addr::validate_shout_bit_addressing(lut_inst)?; - let inst_ell_addr = lut_inst.d * lut_inst.ell; - let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) - .map_err(|_| PiCcsError::InvalidInput("Shout: ell_addr overflows u32".into()))?; - required_ell_addrs.insert(inst_ell_addr_u32); - for _lane_idx in 0..lut_inst.lanes.max(1) { - lane_ell_addr.push(inst_ell_addr_u32); - } - } - if lane_ell_addr.len() != total_lanes { - return Err(PiCcsError::ProtocolError( - "shout addr-pre lane indexing drift (lane_ell_addr)".into(), - )); - } - - // Groups must match the step's required `ell_addr` set and be sorted/unique. 
- if proof.groups.len() != required_ell_addrs.len() { - return Err(PiCcsError::InvalidInput(format!( - "shout_addr_pre groups.len()={}, expected {} (distinct ell_addr values in step)", - proof.groups.len(), - required_ell_addrs.len() - ))); - } - let required_list: Vec = required_ell_addrs.into_iter().collect(); - for (idx, group) in proof.groups.iter().enumerate() { - let expected_ell_addr = required_list[idx]; - if group.ell_addr != expected_ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shout_addr_pre groups not sorted or mismatched: groups[{idx}].ell_addr={} but expected {expected_ell_addr}", - group.ell_addr - ))); - } - if group.r_addr.len() != group.ell_addr as usize { - return Err(PiCcsError::InvalidInput(format!( - "shout_addr_pre group ell_addr={} has r_addr.len()={}, expected {}", - group.ell_addr, - group.r_addr.len(), - group.ell_addr - ))); - } - if group.round_polys.len() != group.active_lanes.len() { - return Err(PiCcsError::InvalidInput(format!( - "shout_addr_pre group ell_addr={} round_polys.len()={}, expected active_lanes.len()={}", - group.ell_addr, - group.round_polys.len(), - group.active_lanes.len() - ))); - } - - for (pos, &lane_idx) in group.active_lanes.iter().enumerate() { - let lane_idx_usize = lane_idx as usize; - if lane_idx_usize >= total_lanes { - return Err(PiCcsError::InvalidInput( - "shout_addr_pre active_lanes has index out of range".into(), - )); - } - if lane_ell_addr[lane_idx_usize] != group.ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shout_addr_pre active_lanes contains lane_idx={} with ell_addr={}, but group ell_addr={}", - lane_idx, lane_ell_addr[lane_idx_usize], group.ell_addr - ))); - } - if pos > 0 && group.active_lanes[pos - 1] >= lane_idx { - return Err(PiCcsError::InvalidInput( - "shout_addr_pre active_lanes must be strictly increasing".into(), - )); - } - } - for (pos, rounds) in group.round_polys.iter().enumerate() { - if rounds.len() != group.ell_addr as usize { - return 
Err(PiCcsError::InvalidInput(format!( - "shout_addr_pre group ell_addr={} round_polys[{pos}].len()={}, expected {}", - group.ell_addr, - rounds.len(), - group.ell_addr - ))); - } - } - } - - // Bind all claimed sums (all lanes) once. - let labels_all: Vec<&'static [u8]> = vec![b"shout/addr_pre".as_slice(); total_lanes]; - tr.append_message(b"shout/addr_pre_time/step_idx", &(step_idx as u64).to_le_bytes()); - bind_batched_claim_sums( - tr, - b"shout/addr_pre_time/claimed_sums", - &proof.claimed_sums, - &labels_all, - ); - - // Verify each `ell_addr` group independently, collecting per-lane addr-pre finals and - // recording the shared `r_addr` for that group. - let mut lane_is_active = vec![false; total_lanes]; - let mut lane_addr_final = vec![K::ZERO; total_lanes]; - let mut r_addr_by_ell: std::collections::BTreeMap> = std::collections::BTreeMap::new(); - let mut seen_active: std::collections::HashSet = std::collections::HashSet::new(); - - for (group_idx, group) in proof.groups.iter().enumerate() { - tr.append_message(b"shout/addr_pre_time/group_idx", &(group_idx as u64).to_le_bytes()); - tr.append_message( - b"shout/addr_pre_time/group_ell_addr", - &(group.ell_addr as u64).to_le_bytes(), - ); - - if group.active_lanes.is_empty() { - // No active lanes in this group: match prover's deterministic fallback sampling. 
- tr.append_message(b"shout/addr_pre_time/no_sumcheck", &(step_idx as u64).to_le_bytes()); - tr.append_message( - b"shout/addr_pre_time/no_sumcheck/ell_addr", - &(group.ell_addr as u64).to_le_bytes(), - ); - let r_addr = ts::sample_ext_point( - tr, - b"shout/addr_pre_time/no_sumcheck/r_addr", - b"shout/addr_pre_time/no_sumcheck/r_addr/0", - b"shout/addr_pre_time/no_sumcheck/r_addr/1", - group.ell_addr as usize, - ); - if r_addr != group.r_addr { - return Err(PiCcsError::ProtocolError( - "shout_addr_pre r_addr mismatch: transcript-derived vs proof".into(), - )); - } - r_addr_by_ell.insert(group.ell_addr, r_addr); - continue; - } - - let active_count = group.active_lanes.len(); - let mut active_claimed_sums: Vec = Vec::with_capacity(active_count); - for &lane_idx in group.active_lanes.iter() { - if !seen_active.insert(lane_idx) { - return Err(PiCcsError::InvalidInput( - "shout_addr_pre active_lanes contains duplicates across groups".into(), - )); - } - active_claimed_sums.push( - *proof - .claimed_sums - .get(lane_idx as usize) - .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre active lane idx drift".into()))?, - ); - } - let labels_active: Vec<&'static [u8]> = vec![b"shout/addr_pre".as_slice(); active_count]; - let degree_bounds = vec![2usize; active_count]; - let (r_addr, finals, ok) = verify_batched_sumcheck_rounds_ds( - tr, - b"shout/addr_pre_time", - step_idx, - &group.round_polys, - &active_claimed_sums, - &labels_active, - °ree_bounds, - ); - if !ok { - return Err(PiCcsError::SumcheckError( - "shout addr-pre batched sumcheck invalid".into(), - )); - } - if r_addr != group.r_addr { - return Err(PiCcsError::ProtocolError( - "shout_addr_pre r_addr mismatch: transcript-derived vs proof".into(), - )); - } - if finals.len() != active_count { - return Err(PiCcsError::ProtocolError(format!( - "shout addr-pre finals.len()={}, expected active_count={active_count}", - finals.len() - ))); - } - - for (pos, &lane_idx) in group.active_lanes.iter().enumerate() { - 
let lane_idx_usize = lane_idx as usize; - lane_is_active[lane_idx_usize] = true; - lane_addr_final[lane_idx_usize] = finals[pos]; - } - r_addr_by_ell.insert(group.ell_addr, r_addr); - } - - // Build per-lane verify data in canonical order. - let mut out = Vec::with_capacity(total_lanes); - for (lut_inst, inst_ell_addr) in step.lut_insts.iter().map(|inst| (inst, inst.d * inst.ell)) { - let expected_lanes = lut_inst.lanes.max(1); - let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) - .map_err(|_| PiCcsError::InvalidInput("Shout: ell_addr overflows u32".into()))?; - let r_addr = r_addr_by_ell - .get(&inst_ell_addr_u32) - .ok_or_else(|| PiCcsError::ProtocolError("missing shout addr-pre group r_addr".into()))?; - - for _lane_idx in 0..expected_lanes { - let flat_lane_idx = out.len(); - let addr_claim_sum = *proof - .claimed_sums - .get(flat_lane_idx) - .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre lane index drift".into()))?; - let is_active = *lane_is_active - .get(flat_lane_idx) - .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre lane idx drift".into()))?; - let addr_final = *lane_addr_final - .get(flat_lane_idx) - .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre lane idx drift".into()))?; - - let table_eval_at_r_addr = if is_active { - match &lut_inst.table_spec { - None => { - let pow2 = 1usize - .checked_shl(r_addr.len() as u32) - .ok_or_else(|| PiCcsError::InvalidInput("Shout: 2^ell_addr overflow".into()))?; - let mut acc = K::ZERO; - for (i, &v) in lut_inst.table.iter().enumerate().take(pow2) { - let w = neo_memory::mle::chi_at_index(r_addr, i); - acc += K::from(v) * w; - } - acc - } - Some(spec) => spec.eval_table_mle(r_addr)?, - } - } else { - K::ZERO - }; - - out.push(ShoutAddrPreVerifyData { - is_active, - addr_claim_sum, - addr_final: if is_active { addr_final } else { K::ZERO }, - r_addr: r_addr.clone(), - table_eval_at_r_addr, - }); - } - } - if out.len() != total_lanes { - return Err(PiCcsError::ProtocolError("shout addr-pre 
lane count mismatch".into())); - } - - Ok(out) -} - -pub fn verify_twist_addr_pre_time( - tr: &mut Poseidon2Transcript, - step: &StepInstanceBundle, - mem_proof: &MemSidecarProof, -) -> Result, PiCcsError> { - let mut out = Vec::with_capacity(step.mem_insts.len()); - let proof_offset = step.lut_insts.len(); - - for (idx, mem_inst) in step.mem_insts.iter().enumerate() { - let proof = match mem_proof.proofs.get(proof_offset + idx) { - Some(MemOrLutProof::Twist(p)) => p, - _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), - }; - - if proof.addr_pre.claimed_sums.len() != 2 { - return Err(PiCcsError::InvalidInput(format!( - "twist addr_pre claimed_sums.len()={}, expected 2", - proof.addr_pre.claimed_sums.len() - ))); - } - if proof.addr_pre.round_polys.len() != 2 { - return Err(PiCcsError::InvalidInput(format!( - "twist addr_pre round_polys.len()={}, expected 2", - proof.addr_pre.round_polys.len() - ))); - } - if proof.addr_pre.claimed_sums[0] != K::ZERO || proof.addr_pre.claimed_sums[1] != K::ZERO { - return Err(PiCcsError::ProtocolError( - "twist addr_pre claimed_sums mismatch (expected both 0)".into(), - )); - } - - let labels: [&[u8]; 2] = [b"twist/read_addr_pre".as_slice(), b"twist/write_addr_pre".as_slice()]; - let degree_bounds = vec![2usize, 2usize]; - tr.append_message(b"twist/addr_pre_time/claim_idx", &(idx as u64).to_le_bytes()); - bind_batched_claim_sums( - tr, - b"twist/addr_pre_time/claimed_sums", - &proof.addr_pre.claimed_sums, - &labels, - ); - - let (r_addr, finals, ok) = verify_batched_sumcheck_rounds_ds( - tr, - b"twist/addr_pre_time", - idx, - &proof.addr_pre.round_polys, - &proof.addr_pre.claimed_sums, - &labels, - °ree_bounds, - ); - if !ok { - return Err(PiCcsError::SumcheckError( - "twist addr-pre batched sumcheck invalid".into(), - )); - } - if r_addr != proof.addr_pre.r_addr { - return Err(PiCcsError::ProtocolError( - "twist addr_pre r_addr mismatch: transcript-derived vs proof".into(), - )); - } - - let ell_addr = 
mem_inst.d * mem_inst.ell; - if r_addr.len() != ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "twist addr_pre r_addr.len()={}, expected ell_addr={}", - r_addr.len(), - ell_addr - ))); - } - if finals.len() != 2 { - return Err(PiCcsError::ProtocolError(format!( - "twist addr-pre finals.len()={}, expected 2", - finals.len() - ))); - } - - out.push(TwistAddrPreVerifyData { - r_addr, - read_check_claim_sum: finals[0], - write_check_claim_sum: finals[1], - }); - } - - Ok(out) -} - -pub(crate) fn build_route_a_memory_oracles( - _params: &NeoParams, - step: &StepWitnessBundle, - _ell_n: usize, - r_cycle: &[K], - shout_pre: &ShoutAddrPreBatchProverData, - twist_pre: &[TwistAddrPreProverData], -) -> Result { - if shout_pre.decoded.len() != step.lut_instances.len() { - return Err(PiCcsError::InvalidInput(format!( - "shout pre-time count mismatch (expected {}, got {})", - step.lut_instances.len(), - shout_pre.decoded.len() - ))); - } - if twist_pre.len() != step.mem_instances.len() { - return Err(PiCcsError::InvalidInput(format!( - "twist pre-time decoded count mismatch (expected {}, got {})", - step.mem_instances.len(), - twist_pre.len() - ))); - } - - let mut shout_oracles = Vec::with_capacity(step.lut_instances.len()); - let mut r_addr_by_ell: std::collections::BTreeMap = std::collections::BTreeMap::new(); - for g in shout_pre.addr_pre.groups.iter() { - r_addr_by_ell.insert(g.ell_addr, g.r_addr.as_slice()); - } - for (lut_idx, ((lut_inst, _lut_wit), decoded)) in step - .lut_instances - .iter() - .zip(shout_pre.decoded.iter()) - .enumerate() - { - let ell_addr = lut_inst.d * lut_inst.ell; - let ell_addr_u32 = u32::try_from(ell_addr) - .map_err(|_| PiCcsError::InvalidInput("Shout(Route A): ell_addr overflows u32".into()))?; - let r_addr = *r_addr_by_ell - .get(&ell_addr_u32) - .ok_or_else(|| PiCcsError::ProtocolError("missing shout addr-pre group r_addr".into()))?; - if r_addr.len() != ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "Shout(Route A): 
r_addr.len()={} != ell_addr={}", - r_addr.len(), - ell_addr - ))); - } - - if decoded.lanes.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "Shout(Route A): decoded lanes empty at lut_idx={lut_idx}" - ))); - } - - let lane_count = decoded.lanes.len(); - let mut lanes: Vec = Vec::with_capacity(lane_count); - - for lane in decoded.lanes.iter() { - let (value_oracle, value_claim) = - ShoutValueOracleSparse::new(r_cycle, lane.has_lookup.clone(), lane.val.clone()); - - let (adapter_oracle, adapter_claim) = IndexAdapterOracleSparseTime::new_with_gate( - r_cycle, - lane.has_lookup.clone(), - lane.addr_bits.clone(), - r_addr, - ); - - lanes.push(RouteAShoutTimeLaneOracles { - value: Box::new(value_oracle), - value_claim, - adapter: Box::new(adapter_oracle), - adapter_claim, - }); - } - - let mut bit_cols: Vec> = Vec::with_capacity(lane_count * (ell_addr + 1)); - for lane in decoded.lanes.iter() { - bit_cols.extend(lane.addr_bits.iter().cloned()); - bit_cols.push(lane.has_lookup.clone()); - } - let weights = bitness_weights(r_cycle, bit_cols.len(), 0x5348_4F55_54u64 + lut_idx as u64); - let bitness_oracle = LazyWeightedBitnessOracleSparseTime::new_with_cycle(r_cycle, bit_cols, weights); - let bitness: Vec> = vec![Box::new(bitness_oracle)]; - - shout_oracles.push(RouteAShoutTimeOracles { - lanes, - bitness, - ell_addr, - }); - } - - let mut twist_oracles = Vec::with_capacity(step.mem_instances.len()); - for (mem_idx, ((mem_inst, _mem_wit), pre)) in step.mem_instances.iter().zip(twist_pre.iter()).enumerate() { - let init_at_r_addr = eval_init_at_r_addr(&mem_inst.init, mem_inst.k, &pre.addr_pre.r_addr)?; - let ell_addr = mem_inst.d * mem_inst.ell; - if pre.addr_pre.r_addr.len() != ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "Twist(Route A): r_addr.len()={} != ell_addr={}", - pre.addr_pre.r_addr.len(), - ell_addr - ))); - } - - if pre.decoded.lanes.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "Twist(Route A): decoded lanes empty at 
mem_idx={mem_idx}" - ))); - } - - let inc_terms_at_r_addr = build_twist_inc_terms_at_r_addr(&pre.decoded.lanes, &pre.addr_pre.r_addr); - - let mut read_oracles: Vec> = Vec::with_capacity(pre.decoded.lanes.len()); - let mut write_oracles: Vec> = Vec::with_capacity(pre.decoded.lanes.len()); - for lane in pre.decoded.lanes.iter() { - read_oracles.push(Box::new(TwistReadCheckOracleSparseTime::new_with_inc_terms( - r_cycle, - lane.has_read.clone(), - lane.rv.clone(), - lane.ra_bits.clone(), - &pre.addr_pre.r_addr, - init_at_r_addr, - inc_terms_at_r_addr.clone(), - ))); - write_oracles.push(Box::new(TwistWriteCheckOracleSparseTime::new_with_inc_terms( - r_cycle, - lane.has_write.clone(), - lane.wv.clone(), - lane.inc_at_write_addr.clone(), - lane.wa_bits.clone(), - &pre.addr_pre.r_addr, - init_at_r_addr, - inc_terms_at_r_addr.clone(), - ))); - } - let read_check: Box = Box::new(SumRoundOracle::new(read_oracles)); - let write_check: Box = Box::new(SumRoundOracle::new(write_oracles)); - - let lane_count = pre.decoded.lanes.len(); - let mut bit_cols: Vec> = Vec::with_capacity(lane_count * (2 * ell_addr + 2)); - for lane in pre.decoded.lanes.iter() { - bit_cols.extend(lane.ra_bits.iter().cloned()); - bit_cols.extend(lane.wa_bits.iter().cloned()); - bit_cols.push(lane.has_read.clone()); - bit_cols.push(lane.has_write.clone()); - } - let weights = bitness_weights(r_cycle, bit_cols.len(), 0x5457_4953_54u64 + mem_idx as u64); - let bitness_oracle = LazyWeightedBitnessOracleSparseTime::new_with_cycle(r_cycle, bit_cols, weights); - let bitness: Vec> = vec![Box::new(bitness_oracle)]; - - twist_oracles.push(RouteATwistTimeOracles { - read_check, - write_check, - bitness, - ell_addr, - }); - } - - Ok(RouteAMemoryOracles { - shout: shout_oracles, - twist: twist_oracles, - }) -} - -pub struct RouteAShoutTimeClaimsGuard<'a> { - pub lane_ranges: Vec>, - pub lanes: Vec>, - pub bitness: Vec>>, -} - -pub struct RouteAShoutTimeLaneClaims<'a> { - pub value_prefix: RoundOraclePrefix<'a>, - pub 
adapter_prefix: RoundOraclePrefix<'a>, - pub value_claim: K, - pub adapter_claim: K, -} - -pub fn build_route_a_shout_time_claims_guard<'a>( - shout_oracles: &'a mut [RouteAShoutTimeOracles], - ell_n: usize, -) -> RouteAShoutTimeClaimsGuard<'a> { - let mut lane_ranges: Vec> = Vec::with_capacity(shout_oracles.len()); - let mut lanes: Vec> = Vec::new(); - let mut bitness: Vec>> = Vec::with_capacity(shout_oracles.len()); - - for o in shout_oracles.iter_mut() { - bitness.push(core::mem::take(&mut o.bitness)); - let start = lanes.len(); - for lane in o.lanes.iter_mut() { - lanes.push(RouteAShoutTimeLaneClaims { - value_prefix: RoundOraclePrefix::new(lane.value.as_mut(), ell_n), - adapter_prefix: RoundOraclePrefix::new(lane.adapter.as_mut(), ell_n), - value_claim: lane.value_claim, - adapter_claim: lane.adapter_claim, - }); - } - let end = lanes.len(); - lane_ranges.push(start..end); - } - - RouteAShoutTimeClaimsGuard { - lane_ranges, - lanes, - bitness, - } -} - -pub struct ShoutRouteAProtocol<'a> { - guard: RouteAShoutTimeClaimsGuard<'a>, -} - -impl<'a> ShoutRouteAProtocol<'a> { - pub fn new(shout_oracles: &'a mut [RouteAShoutTimeOracles], ell_n: usize) -> Self { - Self { - guard: build_route_a_shout_time_claims_guard(shout_oracles, ell_n), - } - } -} - -impl<'o> TimeBatchedClaims for ShoutRouteAProtocol<'o> { - fn append_time_claims<'a>( - &'a mut self, - _ell_n: usize, - claimed_sums: &mut Vec, - degree_bounds: &mut Vec, - labels: &mut Vec<&'static [u8]>, - claim_is_dynamic: &mut Vec, - claims: &mut Vec>, - ) { - append_route_a_shout_time_claims( - &mut self.guard, - claimed_sums, - degree_bounds, - labels, - claim_is_dynamic, - claims, - ); - } -} - -pub fn append_route_a_shout_time_claims<'a>( - guard: &'a mut RouteAShoutTimeClaimsGuard<'_>, - claimed_sums: &mut Vec, - degree_bounds: &mut Vec, - labels: &mut Vec<&'static [u8]>, - claim_is_dynamic: &mut Vec, - claims: &mut Vec>, -) { - if guard.lane_ranges.is_empty() { - return; - } - if guard.bitness.len() != 
guard.lane_ranges.len() { - panic!("shout bitness count mismatch"); - } - - let mut lane_ranges_iter = guard.lane_ranges.iter(); - let mut next_end = lane_ranges_iter.next().expect("non-empty").end; - let mut bitness_iter = guard.bitness.iter_mut(); - - for (lane_idx, lane) in guard.lanes.iter_mut().enumerate() { - claimed_sums.push(lane.value_claim); - degree_bounds.push(lane.value_prefix.degree_bound()); - labels.push(b"shout/value"); - claim_is_dynamic.push(true); - claims.push(BatchedClaim { - oracle: &mut lane.value_prefix, - claimed_sum: lane.value_claim, - label: b"shout/value", - }); - - claimed_sums.push(lane.adapter_claim); - degree_bounds.push(lane.adapter_prefix.degree_bound()); - labels.push(b"shout/adapter"); - claim_is_dynamic.push(true); - claims.push(BatchedClaim { - oracle: &mut lane.adapter_prefix, - claimed_sum: lane.adapter_claim, - label: b"shout/adapter", - }); - - if lane_idx + 1 == next_end { - let bitness_vec = bitness_iter.next().expect("shout bitness idx drift"); - for bit_oracle in bitness_vec.iter_mut() { - claimed_sums.push(K::ZERO); - degree_bounds.push(bit_oracle.degree_bound()); - labels.push(b"shout/bitness"); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle: bit_oracle.as_mut(), - claimed_sum: K::ZERO, - label: b"shout/bitness", - }); - } - - next_end = lane_ranges_iter.next().map(|r| r.end).unwrap_or(usize::MAX); - } - } - - if bitness_iter.next().is_some() { - panic!("shout bitness not fully consumed"); - } -} - -pub struct RouteATwistTimeClaimsGuard<'a> { - pub read_check_prefixes: Vec>, - pub write_check_prefixes: Vec>, - pub read_check_claims: Vec, - pub write_check_claims: Vec, - pub bitness: Vec>>, -} - -pub fn build_route_a_twist_time_claims_guard<'a>( - twist_oracles: &'a mut [RouteATwistTimeOracles], - ell_n: usize, - read_check_claims: Vec, - write_check_claims: Vec, -) -> RouteATwistTimeClaimsGuard<'a> { - let mut read_check_prefixes: Vec> = Vec::with_capacity(twist_oracles.len()); - let mut 
write_check_prefixes: Vec> = Vec::with_capacity(twist_oracles.len()); - let mut bitness: Vec>> = Vec::with_capacity(twist_oracles.len()); - - if read_check_claims.len() != twist_oracles.len() { - panic!( - "twist read-check claim count mismatch (claims={}, oracles={})", - read_check_claims.len(), - twist_oracles.len() - ); - } - if write_check_claims.len() != twist_oracles.len() { - panic!( - "twist write-check claim count mismatch (claims={}, oracles={})", - write_check_claims.len(), - twist_oracles.len() - ); - } - - for o in twist_oracles.iter_mut() { - bitness.push(core::mem::take(&mut o.bitness)); - read_check_prefixes.push(RoundOraclePrefix::new(o.read_check.as_mut(), ell_n)); - write_check_prefixes.push(RoundOraclePrefix::new(o.write_check.as_mut(), ell_n)); - } - - RouteATwistTimeClaimsGuard { - read_check_prefixes, - write_check_prefixes, - read_check_claims, - write_check_claims, - bitness, - } -} - -pub fn append_route_a_twist_time_claims<'a>( - guard: &'a mut RouteATwistTimeClaimsGuard<'_>, - claimed_sums: &mut Vec, - degree_bounds: &mut Vec, - labels: &mut Vec<&'static [u8]>, - claim_is_dynamic: &mut Vec, - claims: &mut Vec>, -) { - for (((read_check_time, write_check_time), bitness_vec), (read_claim, write_claim)) in guard - .read_check_prefixes - .iter_mut() - .zip(guard.write_check_prefixes.iter_mut()) - .zip(guard.bitness.iter_mut()) - .zip( - guard - .read_check_claims - .iter() - .zip(guard.write_check_claims.iter()), - ) - { - claimed_sums.push(*read_claim); - degree_bounds.push(read_check_time.degree_bound()); - labels.push(b"twist/read_check"); - claim_is_dynamic.push(true); - claims.push(BatchedClaim { - oracle: read_check_time, - claimed_sum: *read_claim, - label: b"twist/read_check", - }); - - claimed_sums.push(*write_claim); - degree_bounds.push(write_check_time.degree_bound()); - labels.push(b"twist/write_check"); - claim_is_dynamic.push(true); - claims.push(BatchedClaim { - oracle: write_check_time, - claimed_sum: *write_claim, - label: 
b"twist/write_check", - }); - - for bit_oracle in bitness_vec.iter_mut() { - claimed_sums.push(K::ZERO); - degree_bounds.push(bit_oracle.degree_bound()); - labels.push(b"twist/bitness"); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle: bit_oracle.as_mut(), - claimed_sum: K::ZERO, - label: b"twist/bitness", - }); - } - } -} - -pub struct TwistRouteAProtocol<'a> { - guard: RouteATwistTimeClaimsGuard<'a>, -} - -impl<'a> TwistRouteAProtocol<'a> { - pub fn new( - twist_oracles: &'a mut [RouteATwistTimeOracles], - ell_n: usize, - read_check_claims: Vec, - write_check_claims: Vec, - ) -> Self { - Self { - guard: build_route_a_twist_time_claims_guard(twist_oracles, ell_n, read_check_claims, write_check_claims), - } - } -} - -impl<'o> TimeBatchedClaims for TwistRouteAProtocol<'o> { - fn append_time_claims<'a>( - &'a mut self, - _ell_n: usize, - claimed_sums: &mut Vec, - degree_bounds: &mut Vec, - labels: &mut Vec<&'static [u8]>, - claim_is_dynamic: &mut Vec, - claims: &mut Vec>, - ) { - append_route_a_twist_time_claims( - &mut self.guard, - claimed_sums, - degree_bounds, - labels, - claim_is_dynamic, - claims, - ); - } -} - -pub(crate) fn finalize_route_a_memory_prover( - tr: &mut Poseidon2Transcript, - params: &NeoParams, - cpu_bus: &BusLayout, - s: &CcsStructure, - step: &StepWitnessBundle, - prev_step: Option<&StepWitnessBundle>, - prev_twist_decoded: Option<&[TwistDecodedColsSparse]>, - oracles: &mut RouteAMemoryOracles, - shout_addr_pre: &ShoutAddrPreProof, - twist_pre: &[TwistAddrPreProverData], - r_time: &[K], - m_in: usize, - step_idx: usize, -) -> Result, PiCcsError> { - let has_prev = prev_step.is_some(); - if has_prev != prev_twist_decoded.is_some() { - return Err(PiCcsError::InvalidInput(format!( - "Twist rollover decoded cache mismatch: prev_step.is_some()={} but prev_twist_decoded.is_some()={}", - has_prev, - prev_twist_decoded.is_some() - ))); - } - let total_lanes: usize = step - .lut_instances - .iter() - .map(|(inst, _)| 
inst.lanes.max(1)) - .sum(); - if shout_addr_pre.claimed_sums.len() != total_lanes { - return Err(PiCcsError::InvalidInput(format!( - "shout addr-pre proof count mismatch (expected claimed_sums.len()=total_lanes={}, got {})", - total_lanes, - shout_addr_pre.claimed_sums.len(), - ))); - } - { - let mut lane_ell_addr: Vec = Vec::with_capacity(total_lanes); - let mut required_ell_addrs: std::collections::BTreeSet = std::collections::BTreeSet::new(); - for (lut_inst, _lut_wit) in step.lut_instances.iter().map(|(inst, wit)| (inst, wit)) { - let inst_ell_addr = lut_inst.d * lut_inst.ell; - let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) - .map_err(|_| PiCcsError::InvalidInput("Shout: ell_addr overflows u32".into()))?; - required_ell_addrs.insert(inst_ell_addr_u32); - for _lane_idx in 0..lut_inst.lanes.max(1) { - lane_ell_addr.push(inst_ell_addr_u32); - } - } - if lane_ell_addr.len() != total_lanes { - return Err(PiCcsError::ProtocolError( - "shout addr-pre lane indexing drift (lane_ell_addr)".into(), - )); - } - - if shout_addr_pre.groups.len() != required_ell_addrs.len() { - return Err(PiCcsError::InvalidInput(format!( - "shout addr-pre group count mismatch (expected {}, got {})", - required_ell_addrs.len(), - shout_addr_pre.groups.len() - ))); - } - let required_list: Vec = required_ell_addrs.into_iter().collect(); - let mut seen_active: std::collections::HashSet = std::collections::HashSet::new(); - for (idx, group) in shout_addr_pre.groups.iter().enumerate() { - let expected_ell_addr = required_list[idx]; - if group.ell_addr != expected_ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shout addr-pre groups not sorted or mismatched: groups[{idx}].ell_addr={} but expected {expected_ell_addr}", - group.ell_addr - ))); - } - if group.r_addr.len() != group.ell_addr as usize { - return Err(PiCcsError::InvalidInput(format!( - "shout addr-pre group ell_addr={} has r_addr.len()={}, expected {}", - group.ell_addr, - group.r_addr.len(), - group.ell_addr - ))); 
- } - if group.round_polys.len() != group.active_lanes.len() { - return Err(PiCcsError::InvalidInput(format!( - "shout addr-pre group ell_addr={} round_polys.len()={}, expected active_lanes.len()={}", - group.ell_addr, - group.round_polys.len(), - group.active_lanes.len() - ))); - } - for (pos, &lane_idx) in group.active_lanes.iter().enumerate() { - let lane_idx_usize = lane_idx as usize; - if lane_idx_usize >= total_lanes { - return Err(PiCcsError::InvalidInput( - "shout addr-pre active_lanes has index out of range".into(), - )); - } - if lane_ell_addr[lane_idx_usize] != group.ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shout addr-pre active_lanes contains lane_idx={} with ell_addr={}, but group ell_addr={}", - lane_idx, lane_ell_addr[lane_idx_usize], group.ell_addr - ))); - } - if pos > 0 && group.active_lanes[pos - 1] >= lane_idx { - return Err(PiCcsError::InvalidInput( - "shout addr-pre active_lanes must be strictly increasing".into(), - )); - } - if !seen_active.insert(lane_idx) { - return Err(PiCcsError::InvalidInput( - "shout addr-pre active_lanes contains duplicates across groups".into(), - )); - } - } - for (pos, rounds) in group.round_polys.iter().enumerate() { - if rounds.len() != group.ell_addr as usize { - return Err(PiCcsError::InvalidInput(format!( - "shout addr-pre group ell_addr={} round_polys[{pos}].len()={}, expected {}", - group.ell_addr, - rounds.len(), - group.ell_addr - ))); - } - } - } - } - if twist_pre.len() != step.mem_instances.len() { - return Err(PiCcsError::InvalidInput(format!( - "twist pre-time count mismatch (expected {}, got {})", - step.mem_instances.len(), - twist_pre.len() - ))); - } - if oracles.twist.len() != step.mem_instances.len() { - return Err(PiCcsError::InvalidInput(format!( - "twist oracle count mismatch (expected {}, got {})", - step.mem_instances.len(), - oracles.twist.len() - ))); - } - - for (idx, (lut_inst, lut_wit)) in step.lut_instances.iter().enumerate() { - if !lut_inst.comms.is_empty() || 
!lut_wit.mats.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "shared CPU bus requires metadata-only Shout instances (comms/mats must be empty, lut_idx={idx})" - ))); - } - } - for (idx, (mem_inst, mem_wit)) in step.mem_instances.iter().enumerate() { - if !mem_inst.comms.is_empty() || !mem_wit.mats.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "shared CPU bus requires metadata-only Twist instances (comms/mats must be empty, mem_idx={idx})" - ))); - } - } - if let Some(prev) = prev_step { - if prev.mem_instances.len() != step.mem_instances.len() { - return Err(PiCcsError::InvalidInput(format!( - "Twist rollover requires stable mem instance count: prev has {}, current has {}", - prev.mem_instances.len(), - step.mem_instances.len() - ))); - } - for (idx, (mem_inst, mem_wit)) in prev.mem_instances.iter().enumerate() { - if !mem_inst.comms.is_empty() || !mem_wit.mats.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "shared CPU bus requires metadata-only Twist instances (comms/mats must be empty, prev mem_idx={idx})" - ))); - } - } - } - let mut cpu_me_claims_val: Vec> = Vec::new(); - let mut proofs: Vec = Vec::new(); - - // -------------------------------------------------------------------- - // Phase 2: Twist val-eval sum-check (batched across mem instances). 
- // -------------------------------------------------------------------- - let mut twist_val_eval_proofs: Vec> = Vec::new(); - let mut r_val: Vec = Vec::new(); - if !step.mem_instances.is_empty() { - let plan = crate::memory_sidecar::claim_plan::TwistValEvalClaimPlan::build( - step.mem_instances.iter().map(|(inst, _)| inst), - has_prev, - ); - let n_mem = step.mem_instances.len(); - let claims_per_mem = plan.claims_per_mem; - let claim_count = plan.claim_count; - - let mut val_oracles: Vec> = Vec::with_capacity(claim_count); - let mut bind_claims: Vec<(u8, K)> = Vec::with_capacity(claim_count); - let mut claimed_sums: Vec = Vec::with_capacity(claim_count); - - let mut claimed_inc_sums_lt: Vec = Vec::with_capacity(n_mem); - let mut claimed_inc_sums_total: Vec = Vec::with_capacity(n_mem); - let mut claimed_prev_inc_sums_total: Vec> = Vec::with_capacity(n_mem); - - let mut claim_idx = 0usize; - for (i_mem, (mem_inst, _mem_wit)) in step.mem_instances.iter().enumerate() { - let pre = twist_pre - .get(i_mem) - .ok_or_else(|| PiCcsError::ProtocolError("missing Twist pre-time data".into()))?; - let decoded = &pre.decoded; - let r_addr = &pre.addr_pre.r_addr; - if decoded.lanes.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "Twist(Route A): decoded lanes empty at mem_idx={i_mem}" - ))); - } - - let mut lt_oracles: Vec> = Vec::with_capacity(decoded.lanes.len()); - let mut claimed_inc_sum_lt = K::ZERO; - for lane in decoded.lanes.iter() { - let (oracle, claim) = TwistValEvalOracleSparseTime::new( - lane.wa_bits.clone(), - lane.has_write.clone(), - lane.inc_at_write_addr.clone(), - r_addr, - r_time, - ); - lt_oracles.push(Box::new(oracle)); - claimed_inc_sum_lt += claim; - } - let oracle_lt: Box = Box::new(SumRoundOracle::new(lt_oracles)); - - let mut total_oracles: Vec> = Vec::with_capacity(decoded.lanes.len()); - let mut claimed_inc_sum_total = K::ZERO; - for lane in decoded.lanes.iter() { - let (oracle, claim) = TwistTotalIncOracleSparseTime::new( - 
lane.wa_bits.clone(), - lane.has_write.clone(), - lane.inc_at_write_addr.clone(), - r_addr, - ); - total_oracles.push(Box::new(oracle)); - claimed_inc_sum_total += claim; - } - let oracle_total: Box = Box::new(SumRoundOracle::new(total_oracles)); - - val_oracles.push(oracle_lt); - bind_claims.push((plan.bind_tags[claim_idx], claimed_inc_sum_lt)); - claimed_sums.push(claimed_inc_sum_lt); - claim_idx += 1; - - val_oracles.push(oracle_total); - bind_claims.push((plan.bind_tags[claim_idx], claimed_inc_sum_total)); - claimed_sums.push(claimed_inc_sum_total); - claim_idx += 1; - - claimed_inc_sums_lt.push(claimed_inc_sum_lt); - claimed_inc_sums_total.push(claimed_inc_sum_total); - - if let Some(prev) = prev_step { - let (prev_inst, _prev_wit) = prev - .mem_instances - .get(i_mem) - .ok_or_else(|| PiCcsError::ProtocolError("missing prev mem instance".into()))?; - if prev_inst.d != mem_inst.d - || prev_inst.ell != mem_inst.ell - || prev_inst.k != mem_inst.k - || prev_inst.lanes != mem_inst.lanes - { - return Err(PiCcsError::InvalidInput(format!( - "Twist rollover requires stable geometry at mem_idx={}: prev (k={}, d={}, ell={}, lanes={}) vs cur (k={}, d={}, ell={}, lanes={})", - i_mem, - prev_inst.k, - prev_inst.d, - prev_inst.ell, - prev_inst.lanes, - mem_inst.k, - mem_inst.d, - mem_inst.ell, - mem_inst.lanes - ))); - } - let prev_decoded = prev_twist_decoded - .ok_or_else(|| PiCcsError::ProtocolError("missing prev Twist decoded cols".into()))? 
- .get(i_mem) - .ok_or_else(|| PiCcsError::ProtocolError("missing prev Twist decoded cols at mem_idx".into()))?; - if prev_decoded.lanes.is_empty() { - return Err(PiCcsError::ProtocolError( - "missing prev Twist decoded cols lanes".into(), - )); - } - - let mut prev_total_oracles: Vec> = Vec::with_capacity(prev_decoded.lanes.len()); - let mut claimed_prev_total = K::ZERO; - for lane in prev_decoded.lanes.iter() { - let (oracle, claim) = TwistTotalIncOracleSparseTime::new( - lane.wa_bits.clone(), - lane.has_write.clone(), - lane.inc_at_write_addr.clone(), - r_addr, - ); - prev_total_oracles.push(Box::new(oracle)); - claimed_prev_total += claim; - } - let oracle_prev_total: Box = Box::new(SumRoundOracle::new(prev_total_oracles)); - - val_oracles.push(oracle_prev_total); - bind_claims.push((plan.bind_tags[claim_idx], claimed_prev_total)); - claimed_sums.push(claimed_prev_total); - claim_idx += 1; - - claimed_prev_inc_sums_total.push(Some(claimed_prev_total)); - } else { - claimed_prev_inc_sums_total.push(None); - } - } - - tr.append_message( - b"twist/val_eval/batch_start", - &(step.mem_instances.len() as u64).to_le_bytes(), - ); - tr.append_message(b"twist/val_eval/step_idx", &(step_idx as u64).to_le_bytes()); - bind_twist_val_eval_claim_sums(tr, &bind_claims); - - let mut claims: Vec> = val_oracles - .iter_mut() - .zip(claimed_sums.iter()) - .zip(plan.labels.iter()) - .map(|((oracle, sum), label)| BatchedClaim { - oracle: oracle.as_mut(), - claimed_sum: *sum, - label: *label, - }) - .collect(); - - let (r_val_out, per_claim_results) = - run_batched_sumcheck_prover_ds(tr, b"twist/val_eval_batch", step_idx, claims.as_mut_slice())?; - - if per_claim_results.len() != claim_count { - return Err(PiCcsError::ProtocolError(format!( - "twist val-eval results count mismatch (expected {}, got {})", - claim_count, - per_claim_results.len() - ))); - } - if r_val_out.len() != r_time.len() { - return Err(PiCcsError::ProtocolError(format!( - "twist val-eval r_val.len()={}, expected 
ell_n={}", - r_val_out.len(), - r_time.len() - ))); - } - r_val = r_val_out; - - for i in 0..n_mem { - let base = claims_per_mem * i; - twist_val_eval_proofs.push(twist::TwistValEvalProof { - claimed_inc_sum_lt: claimed_inc_sums_lt[i], - rounds_lt: per_claim_results[base].round_polys.clone(), - claimed_inc_sum_total: claimed_inc_sums_total[i], - rounds_total: per_claim_results[base + 1].round_polys.clone(), - claimed_prev_inc_sum_total: claimed_prev_inc_sums_total[i], - rounds_prev_total: has_prev.then(|| per_claim_results[base + 2].round_polys.clone()), - }); - } - - tr.append_message(b"twist/val_eval/batch_done", &[]); - } - - if step.lut_instances.is_empty() { - if !shout_addr_pre.claimed_sums.is_empty() || !shout_addr_pre.groups.is_empty() { - return Err(PiCcsError::ProtocolError( - "shout_addr_pre must be empty when there are no Shout instances".into(), - )); - } - } - - for _ in 0..step.lut_instances.len() { - proofs.push(MemOrLutProof::Shout(ShoutProofK::default())); - } - - for idx in 0..step.mem_instances.len() { - let mut proof = TwistProofK::default(); - proof.addr_pre = twist_pre - .get(idx) - .ok_or_else(|| PiCcsError::ProtocolError("missing Twist addr_pre".into()))? - .addr_pre - .clone(); - proof.val_eval = twist_val_eval_proofs.get(idx).cloned(); - - proofs.push(MemOrLutProof::Twist(proof)); - } - - if !step.mem_instances.is_empty() { - if r_val.len() != r_time.len() { - return Err(PiCcsError::ProtocolError(format!( - "twist val-eval r_val.len()={}, expected ell_n={}", - r_val.len(), - r_time.len() - ))); - } - - // In shared-bus mode, val-lane checks read bus openings from CPU ME claims at r_val. - // Emit CPU ME at r_val for current step (and previous step for rollover). 
- let (mcs_inst, mcs_wit) = &step.mcs; - let core_t = s.t(); - let cpu_claims_cur = ts::emit_me_claims_for_mats( - tr, - b"cpu_bus/me_digest_val", - params, - s, - core::slice::from_ref(&mcs_inst.c), - core::slice::from_ref(&mcs_wit.Z), - &r_val, - m_in, - )?; - if cpu_claims_cur.len() != 1 { - return Err(PiCcsError::ProtocolError(format!( - "expected exactly 1 CPU ME claim at r_val, got {}", - cpu_claims_cur.len() - ))); - } - let mut cpu_claims_cur = cpu_claims_cur; - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( - params, - cpu_bus, - core_t, - &mcs_wit.Z, - &mut cpu_claims_cur[0], - )?; - cpu_me_claims_val.extend(cpu_claims_cur); - - if let Some(prev) = prev_step { - let (prev_mcs_inst, prev_mcs_wit) = &prev.mcs; - let cpu_claims_prev = ts::emit_me_claims_for_mats( - tr, - b"cpu_bus/me_digest_val", - params, - s, - core::slice::from_ref(&prev_mcs_inst.c), - core::slice::from_ref(&prev_mcs_wit.Z), - &r_val, - m_in, - )?; - if cpu_claims_prev.len() != 1 { - return Err(PiCcsError::ProtocolError(format!( - "expected exactly 1 prev CPU ME claim at r_val, got {}", - cpu_claims_prev.len() - ))); - } - let mut cpu_claims_prev = cpu_claims_prev; - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( - params, - cpu_bus, - core_t, - &prev_mcs_wit.Z, - &mut cpu_claims_prev[0], - )?; - cpu_me_claims_val.extend(cpu_claims_prev); - } - } - - if step.mem_instances.is_empty() { - if !twist_val_eval_proofs.is_empty() { - return Err(PiCcsError::ProtocolError( - "twist val-eval proofs must be empty when no mem instances are present".into(), - )); - } - if !r_val.is_empty() { - return Err(PiCcsError::ProtocolError( - "twist r_val must be empty when no mem instances are present".into(), - )); - } - if !cpu_me_claims_val.is_empty() { - return Err(PiCcsError::ProtocolError( - "twist val-lane ME claims must be empty when no mem instances are present".into(), - )); - } - } else if cpu_me_claims_val.is_empty() { - return 
Err(PiCcsError::ProtocolError( - "twist val-eval requires non-empty val-lane ME claims".into(), - )); - } - - Ok(MemSidecarProof { - cpu_me_claims_val, - shout_addr_pre: shout_addr_pre.clone(), - proofs, - }) -} - -// ============================================================================ -// ============================================================================ -pub fn verify_route_a_memory_step( - tr: &mut Poseidon2Transcript, - cpu_bus: &BusLayout, - step: &StepInstanceBundle, - prev_step: Option<&StepInstanceBundle>, - ccs_out0: &MeInstance, - r_time: &[K], - r_cycle: &[K], - batched_final_values: &[K], - batched_claimed_sums: &[K], - claim_idx_start: usize, - mem_proof: &MemSidecarProof, - shout_pre: &[ShoutAddrPreVerifyData], - twist_pre: &[TwistAddrPreVerifyData], - step_idx: usize, -) -> Result { - let chi_cycle_at_r_time = eq_points(r_time, r_cycle); - if ccs_out0.r.as_slice() != r_time { - return Err(PiCcsError::ProtocolError( - "CPU ME output r mismatch (expected shared r_time)".into(), - )); - } - let has_prev = prev_step.is_some(); - if let Some(prev) = prev_step { - if prev.mem_insts.len() != step.mem_insts.len() { - return Err(PiCcsError::InvalidInput(format!( - "Twist rollover requires stable mem instance count: prev has {}, current has {}", - prev.mem_insts.len(), - step.mem_insts.len() - ))); - } - for (idx, (prev_inst, inst)) in prev.mem_insts.iter().zip(step.mem_insts.iter()).enumerate() { - if prev_inst.d != inst.d - || prev_inst.ell != inst.ell - || prev_inst.k != inst.k - || prev_inst.lanes != inst.lanes - { - return Err(PiCcsError::InvalidInput(format!( - "Twist rollover requires stable geometry at mem_idx={}: prev (k={}, d={}, ell={}, lanes={}) vs cur (k={}, d={}, ell={}, lanes={})", - idx, - prev_inst.k, - prev_inst.d, - prev_inst.ell, - prev_inst.lanes, - inst.k, - inst.d, - inst.ell, - inst.lanes - ))); - } - } - } - - for (idx, inst) in step.lut_insts.iter().enumerate() { - if !inst.comms.is_empty() { - return 
Err(PiCcsError::InvalidInput(format!( - "shared CPU bus requires metadata-only Shout instances (comms must be empty, lut_idx={idx})" - ))); - } - } - for (idx, inst) in step.mem_insts.iter().enumerate() { - if !inst.comms.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "shared CPU bus requires metadata-only Twist instances (comms must be empty, mem_idx={idx})" - ))); - } - } - if let Some(prev) = prev_step { - for (idx, inst) in prev.lut_insts.iter().enumerate() { - if !inst.comms.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "shared CPU bus requires metadata-only Shout instances (comms must be empty, prev lut_idx={idx})" - ))); - } - } - for (idx, inst) in prev.mem_insts.iter().enumerate() { - if !inst.comms.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "shared CPU bus requires metadata-only Twist instances (comms must be empty, prev mem_idx={idx})" - ))); - } - } - } - - let proofs_mem = &mem_proof.proofs; - - if cpu_bus.shout_cols.len() != step.lut_insts.len() || cpu_bus.twist_cols.len() != step.mem_insts.len() { - return Err(PiCcsError::InvalidInput( - "shared_cpu_bus layout mismatch for step (instance counts)".into(), - )); - } - - let bus_y_base_time = if cpu_bus.bus_cols > 0 { - ccs_out0 - .y_scalars - .len() - .checked_sub(cpu_bus.bus_cols) - .ok_or_else(|| PiCcsError::InvalidInput("CPU y_scalars too short for bus openings".into()))? 
- } else { - 0usize - }; - let claim_plan = RouteATimeClaimPlan::build(step, claim_idx_start)?; - if claim_plan.claim_idx_end > batched_final_values.len() { - return Err(PiCcsError::InvalidInput(format!( - "batched_final_values too short (need at least {}, have {})", - claim_plan.claim_idx_end, - batched_final_values.len() - ))); - } - if claim_plan.claim_idx_end > batched_claimed_sums.len() { - return Err(PiCcsError::InvalidInput(format!( - "batched_claimed_sums too short (need at least {}, have {})", - claim_plan.claim_idx_end, - batched_claimed_sums.len() - ))); - } - - let expected_proofs = step.lut_insts.len() + step.mem_insts.len(); - if proofs_mem.len() != expected_proofs { - return Err(PiCcsError::InvalidInput(format!( - "mem proof count mismatch (expected {}, got {})", - expected_proofs, - proofs_mem.len() - ))); - } - let total_shout_lanes: usize = step.lut_insts.iter().map(|inst| inst.lanes.max(1)).sum(); - if shout_pre.len() != total_shout_lanes { - return Err(PiCcsError::InvalidInput(format!( - "shout pre-time count mismatch (expected total_lanes={}, got {})", - total_shout_lanes, - shout_pre.len() - ))); - } - if twist_pre.len() != step.mem_insts.len() { - return Err(PiCcsError::InvalidInput(format!( - "twist pre-time count mismatch (expected {}, got {})", - step.mem_insts.len(), - twist_pre.len() - ))); - } - - let mut twist_time_openings: Vec = Vec::with_capacity(step.mem_insts.len()); - - // Shout instances first. 
- let mut shout_lane_base: usize = 0; - for (proof_idx, inst) in step.lut_insts.iter().enumerate() { - match &proofs_mem[proof_idx] { - MemOrLutProof::Shout(_proof) => {} - _ => return Err(PiCcsError::InvalidInput("expected Shout proof".into())), - } - - let ell_addr = inst.d * inst.ell; - let expected_lanes = inst.lanes.max(1); - - let inst_cols = cpu_bus - .shout_cols - .get(proof_idx) - .ok_or_else(|| PiCcsError::InvalidInput("shared_cpu_bus layout mismatch (shout)".into()))?; - if inst_cols.lanes.len() != expected_lanes { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at lut_idx={proof_idx}: bus shout lanes={} but instance expects {expected_lanes}", - inst_cols.lanes.len() - ))); - } - - struct ShoutLaneOpen { - addr_bits: Vec, - has_lookup: K, - val: K, - } - let mut lane_opens: Vec = Vec::with_capacity(expected_lanes); - for (lane_idx, shout_cols) in inst_cols.lanes.iter().enumerate() { - if shout_cols.addr_bits.end - shout_cols.addr_bits.start != ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at lut_idx={proof_idx}, lane_idx={lane_idx}: expected ell_addr={ell_addr}" - ))); - } - - let mut addr_bits_open = Vec::with_capacity(ell_addr); - for (_j, col_id) in shout_cols.addr_bits.clone().enumerate() { - addr_bits_open.push( - ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, col_id)) - .copied() - .ok_or_else(|| { - PiCcsError::ProtocolError("CPU y_scalars missing Shout addr_bits opening".into()) - })?, - ); - } - let has_lookup_open = ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, shout_cols.has_lookup)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Shout has_lookup opening".into()))?; - let val_open = ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, shout_cols.val)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Shout val opening".into()))?; - - 
lane_opens.push(ShoutLaneOpen { - addr_bits: addr_bits_open, - has_lookup: has_lookup_open, - val: val_open, - }); - } - - let shout_claims = claim_plan - .shout - .get(proof_idx) - .ok_or_else(|| PiCcsError::ProtocolError(format!("missing Shout claim schedule at index {}", proof_idx)))?; - if shout_claims.lanes.len() != expected_lanes { - return Err(PiCcsError::ProtocolError(format!( - "Shout claim schedule lane count mismatch at lut_idx={proof_idx}: expected {expected_lanes}, got {}", - shout_claims.lanes.len() - ))); - } - if shout_lane_base - .checked_add(expected_lanes) - .ok_or_else(|| PiCcsError::ProtocolError("shout lane index overflow".into()))? - > shout_pre.len() - { - return Err(PiCcsError::ProtocolError("Shout pre-time lane indexing drift".into())); - } - - // Route A Shout ordering in batched_time: - // - value (time rounds only) per lane - // - adapter (time rounds only) per lane - // - aggregated bitness for (addr_bits, has_lookup) - { - let mut opens: Vec = Vec::with_capacity(expected_lanes * (ell_addr + 1)); - for lane in lane_opens.iter() { - opens.extend_from_slice(&lane.addr_bits); - opens.push(lane.has_lookup); - } - let weights = bitness_weights(r_cycle, opens.len(), 0x5348_4F55_54u64 + proof_idx as u64); - let mut acc = K::ZERO; - for (w, b) in weights.iter().zip(opens.iter()) { - acc += *w * *b * (*b - K::ONE); - } - let expected = chi_cycle_at_r_time * acc; - if expected != batched_final_values[shout_claims.bitness] { - return Err(PiCcsError::ProtocolError( - "shout/bitness terminal value mismatch".into(), - )); - } - } - - for (lane_idx, lane) in lane_opens.iter().enumerate() { - let pre = shout_pre.get(shout_lane_base + lane_idx).ok_or_else(|| { - PiCcsError::InvalidInput(format!( - "missing pre-time Shout lane data at index {}", - shout_lane_base + lane_idx - )) - })?; - let lane_claims = shout_claims - .lanes - .get(lane_idx) - .ok_or_else(|| PiCcsError::ProtocolError("shout claim schedule lane idx drift".into()))?; - - let value_claim 
= batched_claimed_sums[lane_claims.value]; - let value_final = batched_final_values[lane_claims.value]; - let adapter_claim = batched_claimed_sums[lane_claims.adapter]; - let adapter_final = batched_final_values[lane_claims.adapter]; - - let expected_value_final = chi_cycle_at_r_time * lane.has_lookup * lane.val; - if expected_value_final != value_final { - return Err(PiCcsError::ProtocolError("shout value terminal value mismatch".into())); - } - - let eq_addr = eq_bits_prod(&lane.addr_bits, &pre.r_addr)?; - let expected_adapter_final = chi_cycle_at_r_time * lane.has_lookup * eq_addr; - if expected_adapter_final != adapter_final { - return Err(PiCcsError::ProtocolError( - "shout adapter terminal value mismatch".into(), - )); - } - - if value_claim != pre.addr_claim_sum { - return Err(PiCcsError::ProtocolError( - "shout value claimed sum != addr claimed sum".into(), - )); - } - - if pre.is_active { - let expected_addr_final = pre.table_eval_at_r_addr * adapter_claim; - if expected_addr_final != pre.addr_final { - return Err(PiCcsError::ProtocolError("shout addr terminal value mismatch".into())); - } - } else { - // If we skipped the addr-pre sumcheck, the only sound case is "no lookups". - // Enforce this by requiring the addr claim + adapter claim to be zero. - if pre.addr_claim_sum != K::ZERO { - return Err(PiCcsError::ProtocolError( - "shout addr-pre skipped but addr claim is nonzero".into(), - )); - } - if adapter_claim != K::ZERO { - return Err(PiCcsError::ProtocolError( - "shout addr-pre skipped but adapter claim is nonzero".into(), - )); - } - if pre.addr_final != K::ZERO { - return Err(PiCcsError::ProtocolError( - "shout addr-pre skipped but addr_final is nonzero".into(), - )); - } - } - } - - shout_lane_base += expected_lanes; - } - if shout_lane_base != shout_pre.len() { - return Err(PiCcsError::ProtocolError( - "shout pre-time lanes not fully consumed".into(), - )); - } - - // Twist instances next. 
- let proof_mem_offset = step.lut_insts.len(); - - // -------------------------------------------------------------------- - // Twist time checks at addr-pre `r_addr`. - // -------------------------------------------------------------------- - for (i_mem, inst) in step.mem_insts.iter().enumerate() { - let twist_proof = match &proofs_mem[proof_mem_offset + i_mem] { - MemOrLutProof::Twist(proof) => proof, - _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), - }; - let layout = inst.twist_layout(); - let ell_addr = layout - .lanes - .get(0) - .ok_or_else(|| PiCcsError::InvalidInput("TwistWitnessLayout has no lanes".into()))? - .ell_addr; - - let twist_inst_cols = cpu_bus - .twist_cols - .get(i_mem) - .ok_or_else(|| PiCcsError::InvalidInput("shared_cpu_bus layout mismatch (twist)".into()))?; - let expected_lanes = inst.lanes.max(1); - if twist_inst_cols.lanes.len() != expected_lanes { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at mem_idx={i_mem}: expected lanes={expected_lanes}, got {}", - twist_inst_cols.lanes.len() - ))); - } - - struct TwistLaneTimeOpen { - ra_bits: Vec, - wa_bits: Vec, - has_read: K, - has_write: K, - wv: K, - rv: K, - inc: K, - } - - let mut lane_opens: Vec = Vec::with_capacity(twist_inst_cols.lanes.len()); - for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { - if twist_cols.ra_bits.end - twist_cols.ra_bits.start != ell_addr - || twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr - { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at mem_idx={i_mem}, lane={lane_idx}: expected ell_addr={ell_addr}" - ))); - } - - let mut ra_bits_open = Vec::with_capacity(ell_addr); - for col_id in twist_cols.ra_bits.clone() { - ra_bits_open.push( - ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, col_id)) - .copied() - .ok_or_else(|| { - PiCcsError::ProtocolError("CPU y_scalars missing Twist ra_bits opening".into()) - 
})?, - ); - } - let mut wa_bits_open = Vec::with_capacity(ell_addr); - for col_id in twist_cols.wa_bits.clone() { - wa_bits_open.push( - ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, col_id)) - .copied() - .ok_or_else(|| { - PiCcsError::ProtocolError("CPU y_scalars missing Twist wa_bits opening".into()) - })?, - ); - } - - let has_read_open = ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.has_read)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist has_read opening".into()))?; - let has_write_open = ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.has_write)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist has_write opening".into()))?; - let wv_open = ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.wv)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist wv opening".into()))?; - let rv_open = ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.rv)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist rv opening".into()))?; - let inc_write_open = ccs_out0 - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.inc)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist inc opening".into()))?; - - lane_opens.push(TwistLaneTimeOpen { - ra_bits: ra_bits_open, - wa_bits: wa_bits_open, - has_read: has_read_open, - has_write: has_write_open, - wv: wv_open, - rv: rv_open, - inc: inc_write_open, - }); - } - - let pre = twist_pre - .get(i_mem) - .ok_or_else(|| PiCcsError::InvalidInput(format!("missing Twist pre-time data at index {}", i_mem)))?; - let r_addr = &pre.r_addr; - if r_addr.len() != ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "Twist r_addr.len()={}, expected ell_addr={}", - r_addr.len(), - ell_addr - ))); - } - - let twist_claims 
= claim_plan - .twist - .get(i_mem) - .ok_or_else(|| PiCcsError::ProtocolError(format!("missing Twist claim schedule at index {}", i_mem)))?; - - // Route A Twist ordering in batched_time: - // - read_check (time rounds only) - // - write_check (time rounds only) - // - bitness for ra_bits then wa_bits then has_read then has_write (time-only) - let read_check_claim = batched_claimed_sums[twist_claims.read_check]; - let read_check_final = batched_final_values[twist_claims.read_check]; - let write_check_claim = batched_claimed_sums[twist_claims.write_check]; - let write_check_final = batched_final_values[twist_claims.write_check]; - - if read_check_claim != pre.read_check_claim_sum { - return Err(PiCcsError::ProtocolError( - "twist read_check claimed sum != addr-pre final".into(), - )); - } - if write_check_claim != pre.write_check_claim_sum { - return Err(PiCcsError::ProtocolError( - "twist write_check claimed sum != addr-pre final".into(), - )); - } - - // Aggregated bitness terminal check (ra_bits, wa_bits, has_read, has_write). 
- { - let mut opens: Vec = Vec::with_capacity(expected_lanes * (2 * ell_addr + 2)); - for lane in lane_opens.iter() { - opens.extend_from_slice(&lane.ra_bits); - opens.extend_from_slice(&lane.wa_bits); - opens.push(lane.has_read); - opens.push(lane.has_write); - } - let weights = bitness_weights(r_cycle, opens.len(), 0x5457_4953_54u64 + i_mem as u64); - let mut acc = K::ZERO; - for (w, b) in weights.iter().zip(opens.iter()) { - acc += *w * *b * (*b - K::ONE); - } - let expected = chi_cycle_at_r_time * acc; - if expected != batched_final_values[twist_claims.bitness] { - return Err(PiCcsError::ProtocolError( - "twist/bitness terminal value mismatch".into(), - )); - } - } - - let val_eval = twist_proof - .val_eval - .as_ref() - .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing val_eval proof".into()))?; - - let init_at_r_addr = eval_init_at_r_addr(&inst.init, inst.k, r_addr)?; - let claimed_val = init_at_r_addr + val_eval.claimed_inc_sum_lt; - - // Terminal checks for read_check / write_check at (r_time, r_addr). 
- let mut expected_read_check_final = K::ZERO; - let mut expected_write_check_final = K::ZERO; - for lane in lane_opens.iter() { - let read_eq_addr = eq_bits_prod(&lane.ra_bits, r_addr)?; - expected_read_check_final += chi_cycle_at_r_time * lane.has_read * (claimed_val - lane.rv) * read_eq_addr; - - let write_eq_addr = eq_bits_prod(&lane.wa_bits, r_addr)?; - expected_write_check_final += - chi_cycle_at_r_time * lane.has_write * (lane.wv - claimed_val - lane.inc) * write_eq_addr; - } - if expected_read_check_final != read_check_final { - return Err(PiCcsError::ProtocolError( - "twist/read_check terminal value mismatch".into(), - )); - } - - if expected_write_check_final != write_check_final { - return Err(PiCcsError::ProtocolError( - "twist/write_check terminal value mismatch".into(), - )); - } - - twist_time_openings.push(TwistTimeLaneOpenings { - lanes: lane_opens - .into_iter() - .map(|lane| TwistTimeLaneOpeningsLane { - wa_bits: lane.wa_bits, - has_write: lane.has_write, - inc_at_write_addr: lane.inc, - }) - .collect(), - }); - } - - // -------------------------------------------------------------------- - // Phase 2: Verify batched Twist val-eval sum-check, deriving shared r_val. 
- // -------------------------------------------------------------------- - let mut r_val: Vec = Vec::new(); - let mut val_eval_finals: Vec = Vec::new(); - if !step.mem_insts.is_empty() { - let plan = crate::memory_sidecar::claim_plan::TwistValEvalClaimPlan::build(step.mem_insts.iter(), has_prev); - let claim_count = plan.claim_count; - - let mut per_claim_rounds: Vec>> = Vec::with_capacity(claim_count); - let mut per_claim_sums: Vec = Vec::with_capacity(claim_count); - let mut bind_claims: Vec<(u8, K)> = Vec::with_capacity(claim_count); - let mut claim_idx = 0usize; - - for (i_mem, _inst) in step.mem_insts.iter().enumerate() { - let twist_proof = match &proofs_mem[proof_mem_offset + i_mem] { - MemOrLutProof::Twist(proof) => proof, - _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), - }; - let val = twist_proof - .val_eval - .as_ref() - .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing val_eval proof".into()))?; - - per_claim_rounds.push(val.rounds_lt.clone()); - per_claim_sums.push(val.claimed_inc_sum_lt); - bind_claims.push((plan.bind_tags[claim_idx], val.claimed_inc_sum_lt)); - claim_idx += 1; - - per_claim_rounds.push(val.rounds_total.clone()); - per_claim_sums.push(val.claimed_inc_sum_total); - bind_claims.push((plan.bind_tags[claim_idx], val.claimed_inc_sum_total)); - claim_idx += 1; - - if has_prev { - let prev_total = val.claimed_prev_inc_sum_total.ok_or_else(|| { - PiCcsError::InvalidInput("Twist(Route A): missing claimed_prev_inc_sum_total".into()) - })?; - let prev_rounds = val - .rounds_prev_total - .clone() - .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing rounds_prev_total".into()))?; - per_claim_rounds.push(prev_rounds); - per_claim_sums.push(prev_total); - bind_claims.push((plan.bind_tags[claim_idx], prev_total)); - claim_idx += 1; - } else if val.claimed_prev_inc_sum_total.is_some() || val.rounds_prev_total.is_some() { - return Err(PiCcsError::InvalidInput( - "Twist(Route A): rollover fields 
present but prev_step is None".into(), - )); - } - } - - tr.append_message( - b"twist/val_eval/batch_start", - &(step.mem_insts.len() as u64).to_le_bytes(), - ); - tr.append_message(b"twist/val_eval/step_idx", &(step_idx as u64).to_le_bytes()); - bind_twist_val_eval_claim_sums(tr, &bind_claims); - - let (r_val_out, finals_out, ok) = verify_batched_sumcheck_rounds_ds( - tr, - b"twist/val_eval_batch", - step_idx, - &per_claim_rounds, - &per_claim_sums, - &plan.labels, - &plan.degree_bounds, - ); - if !ok { - return Err(PiCcsError::SumcheckError( - "twist val-eval batched sumcheck invalid".into(), - )); - } - if r_val_out.len() != r_time.len() { - return Err(PiCcsError::ProtocolError(format!( - "twist val-eval r_val.len()={}, expected ell_n={}", - r_val_out.len(), - r_time.len() - ))); - } - if finals_out.len() != claim_count { - return Err(PiCcsError::ProtocolError(format!( - "twist val-eval finals.len()={}, expected {}", - finals_out.len(), - claim_count - ))); - } - r_val = r_val_out; - val_eval_finals = finals_out; - - tr.append_message(b"twist/val_eval/batch_done", &[]); - } - - // Verify val-eval terminal identity against CPU ME openings at r_val. 
- let lt = if step.mem_insts.is_empty() { - if !r_val.is_empty() { - return Err(PiCcsError::ProtocolError( - "twist val-eval produced r_val but no mem instances are present".into(), - )); - } - K::ZERO - } else { - if r_val.len() != r_time.len() { - return Err(PiCcsError::ProtocolError(format!( - "twist val-eval r_val.len()={}, expected ell_n={}", - r_val.len(), - r_time.len() - ))); - } - lt_eval(&r_val, r_time) - }; - - let (cpu_me_val_cur, cpu_me_val_prev, bus_y_base_val) = if step.mem_insts.is_empty() { - if !mem_proof.cpu_me_claims_val.is_empty() { - return Err(PiCcsError::InvalidInput( - "proof contains val-lane CPU ME claims with no Twist instances".into(), - )); - } - (None, None, 0usize) - } else { - let expected = 1usize + usize::from(has_prev); - if mem_proof.cpu_me_claims_val.len() != expected { - return Err(PiCcsError::InvalidInput(format!( - "shared bus expects {} CPU ME claim(s) at r_val, got {}", - expected, - mem_proof.cpu_me_claims_val.len() - ))); - } - - let cpu_me_cur = mem_proof - .cpu_me_claims_val - .get(0) - .ok_or_else(|| PiCcsError::ProtocolError("missing CPU ME claim at r_val".into()))?; - if cpu_me_cur.r.as_slice() != r_val { - return Err(PiCcsError::ProtocolError( - "CPU ME(val) r mismatch (expected r_val)".into(), - )); - } - if cpu_me_cur.c != step.mcs_inst.c { - return Err(PiCcsError::ProtocolError( - "CPU ME(val) commitment mismatch (current step)".into(), - )); - } - let cpu_me_prev = if has_prev { - let prev_inst = - prev_step.ok_or_else(|| PiCcsError::ProtocolError("prev_step missing with has_prev=true".into()))?; - let cpu_me_prev = mem_proof - .cpu_me_claims_val - .get(1) - .ok_or_else(|| PiCcsError::ProtocolError("missing prev CPU ME claim at r_val".into()))?; - if cpu_me_prev.r.as_slice() != r_val { - return Err(PiCcsError::ProtocolError( - "CPU ME(val/prev) r mismatch (expected r_val)".into(), - )); - } - if cpu_me_prev.c != prev_inst.mcs_inst.c { - return Err(PiCcsError::ProtocolError("CPU ME(val/prev) commitment 
mismatch".into())); - } - Some(cpu_me_prev) - } else { - None - }; - - let bus_y_base_val = cpu_me_cur - .y_scalars - .len() - .checked_sub(cpu_bus.bus_cols) - .ok_or_else(|| PiCcsError::InvalidInput("CPU y_scalars too short for bus openings".into()))?; - - (Some(cpu_me_cur), cpu_me_prev, bus_y_base_val) - }; - - for (i_mem, inst) in step.mem_insts.iter().enumerate() { - let twist_proof = match &proofs_mem[proof_mem_offset + i_mem] { - MemOrLutProof::Twist(proof) => proof, - _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), - }; - let val_eval = twist_proof - .val_eval - .as_ref() - .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing val_eval proof".into()))?; - let layout = inst.twist_layout(); - let ell_addr = layout - .lanes - .get(0) - .ok_or_else(|| PiCcsError::InvalidInput("TwistWitnessLayout has no lanes".into()))? - .ell_addr; - - let cpu_me_cur = - cpu_me_val_cur.ok_or_else(|| PiCcsError::ProtocolError("missing CPU ME claim at r_val".into()))?; - - let twist_inst_cols = cpu_bus - .twist_cols - .get(i_mem) - .ok_or_else(|| PiCcsError::InvalidInput("shared_cpu_bus layout mismatch (twist)".into()))?; - let expected_lanes = inst.lanes.max(1); - if twist_inst_cols.lanes.len() != expected_lanes { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at mem_idx={i_mem}: expected lanes={expected_lanes}, got {}", - twist_inst_cols.lanes.len() - ))); - } - - let r_addr = twist_pre - .get(i_mem) - .ok_or_else(|| PiCcsError::InvalidInput(format!("missing Twist pre-time data at index {}", i_mem)))? 
- .r_addr - .as_slice(); - - let mut inc_at_r_addr_val = K::ZERO; - for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { - if twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at mem_idx={i_mem}, lane={lane_idx}: expected ell_addr={ell_addr}" - ))); - } - - let mut wa_bits_val_open = Vec::with_capacity(ell_addr); - for col_id in twist_cols.wa_bits.clone() { - wa_bits_val_open.push( - cpu_me_cur - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_val, col_id)) - .copied() - .ok_or_else(|| { - PiCcsError::ProtocolError("CPU y_scalars missing wa_bits(val) opening".into()) - })?, - ); - } - let has_write_val_open = cpu_me_cur - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.has_write)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing has_write(val) opening".into()))?; - let inc_at_write_addr_val_open = cpu_me_cur - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.inc)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing inc(val) opening".into()))?; - - let eq_wa_val = eq_bits_prod(&wa_bits_val_open, r_addr)?; - inc_at_r_addr_val += has_write_val_open * inc_at_write_addr_val_open * eq_wa_val; - } - - let expected_lt_final = inc_at_r_addr_val * lt; - let claims_per_mem = if has_prev { 3 } else { 2 }; - let base = claims_per_mem * i_mem; - if expected_lt_final != val_eval_finals[base] { - return Err(PiCcsError::ProtocolError( - "twist/val_eval_lt terminal value mismatch".into(), - )); - } - let expected_total_final = inc_at_r_addr_val; - if expected_total_final != val_eval_finals[base + 1] { - return Err(PiCcsError::ProtocolError( - "twist/val_eval_total terminal value mismatch".into(), - )); - } - - if has_prev { - let prev = - prev_step.ok_or_else(|| PiCcsError::ProtocolError("prev_step missing with has_prev=true".into()))?; - let prev_inst = prev - .mem_insts - 
.get(i_mem) - .ok_or_else(|| PiCcsError::ProtocolError("missing prev mem instance".into()))?; - let cpu_me_prev = cpu_me_val_prev - .ok_or_else(|| PiCcsError::ProtocolError("missing prev CPU ME claim at r_val".into()))?; - - // Terminal check for prev-total: uses previous-step openings at current r_val. - let mut inc_at_r_addr_prev = K::ZERO; - for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { - if twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr { - return Err(PiCcsError::InvalidInput(format!( - "shared_cpu_bus layout mismatch at mem_idx={i_mem}, lane={lane_idx}: expected ell_addr={ell_addr}" - ))); - } - - let mut wa_bits_prev_open = Vec::with_capacity(ell_addr); - for col_id in twist_cols.wa_bits.clone() { - wa_bits_prev_open.push( - cpu_me_prev - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_val, col_id)) - .copied() - .ok_or_else(|| { - PiCcsError::ProtocolError("CPU y_scalars missing wa_bits(prev) opening".into()) - })?, - ); - } - let has_write_prev_open = cpu_me_prev - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.has_write)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing has_write(prev) opening".into()))?; - let inc_prev_open = cpu_me_prev - .y_scalars - .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.inc)) - .copied() - .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing inc(prev) opening".into()))?; - - let eq_wa_prev = eq_bits_prod(&wa_bits_prev_open, r_addr)?; - inc_at_r_addr_prev += has_write_prev_open * inc_prev_open * eq_wa_prev; - } - if inc_at_r_addr_prev != val_eval_finals[base + 2] { - return Err(PiCcsError::ProtocolError( - "twist/rollover_prev_total terminal value mismatch".into(), - )); - } - - // Enforce rollover equation: Init_i(r_addr) == Init_{i-1}(r_addr) + PrevTotal(i). 
- let claimed_prev_total = val_eval - .claimed_prev_inc_sum_total - .ok_or_else(|| PiCcsError::ProtocolError("twist rollover missing claimed_prev_inc_sum_total".into()))?; - let init_prev_at_r_addr = eval_init_at_r_addr(&prev_inst.init, prev_inst.k, r_addr)?; - let init_cur_at_r_addr = eval_init_at_r_addr(&inst.init, inst.k, r_addr)?; - if init_cur_at_r_addr != init_prev_at_r_addr + claimed_prev_total { - return Err(PiCcsError::ProtocolError("twist rollover init check failed".into())); - } - } - } - - Ok(RouteAMemoryVerifyOutput { - claim_idx_end: claim_plan.claim_idx_end, - twist_time_openings, - }) -} +use p3_field::PrimeField64; +use std::collections::{BTreeMap, BTreeSet}; + +#[path = "memory/addr_pre_proofs.rs"] +mod addr_pre_proofs; +#[path = "memory/event_table_context.rs"] +mod event_table_context; +#[path = "memory/route_a_claim_builders.rs"] +mod route_a_claim_builders; +#[path = "memory/route_a_claims.rs"] +mod route_a_claims; +#[path = "memory/route_a_finalize.rs"] +mod route_a_finalize; +#[path = "memory/route_a_oracles.rs"] +mod route_a_oracles; +#[path = "memory/route_a_terminal_checks.rs"] +mod route_a_terminal_checks; +#[path = "memory/route_a_verify.rs"] +mod route_a_verify; +#[path = "memory/sparse_oracles_and_twist_pre.rs"] +mod sparse_oracles_and_twist_pre; +#[path = "memory/transcript_and_common.rs"] +mod transcript_and_common; + +pub use addr_pre_proofs::{verify_shout_addr_pre_time, verify_twist_addr_pre_time}; +pub use route_a_verify::verify_route_a_memory_step; +pub use transcript_and_common::{absorb_step_memory, TwistTimeLaneOpenings}; + +pub(crate) use addr_pre_proofs::*; +pub(crate) use event_table_context::*; +pub(crate) use route_a_claim_builders::*; +pub(crate) use route_a_claims::*; +pub(crate) use route_a_finalize::*; +pub(crate) use route_a_oracles::*; +pub(crate) use route_a_terminal_checks::*; +pub(crate) use sparse_oracles_and_twist_pre::*; +pub(crate) use transcript_and_common::*; diff --git 
a/crates/neo-fold/src/memory_sidecar/memory/addr_pre_proofs.rs b/crates/neo-fold/src/memory_sidecar/memory/addr_pre_proofs.rs new file mode 100644 index 00000000..01021860 --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/addr_pre_proofs.rs @@ -0,0 +1,682 @@ +use super::*; + +pub(crate) fn prove_shout_addr_pre_time( + tr: &mut Poseidon2Transcript, + params: &NeoParams, + step: &StepWitnessBundle, + cpu_bus: &BusLayout, + ell_n: usize, + r_cycle: &[K], + step_idx: usize, +) -> Result { + if step.lut_instances.is_empty() { + return Ok(ShoutAddrPreBatchProverData { + addr_pre: ShoutAddrPreProof::default(), + decoded: Vec::new(), + }); + } + + let pow2_cycle = 1usize << ell_n; + let n_lut = step.lut_instances.len(); + let total_lanes: usize = step + .lut_instances + .iter() + .map(|(inst, _)| inst.lanes.max(1)) + .sum(); + + let mut decoded_cols: Vec = Vec::with_capacity(n_lut); + let mut claimed_sums: Vec = vec![K::ZERO; total_lanes]; + + struct AddrPreGroupBuilder { + active_lanes: Vec, + active_claimed_sums: Vec, + addr_oracles: Vec>, + } + + // Group Shout addr-pre claims by `ell_addr` so we can run one batched sumcheck per group. 
+ let mut groups: std::collections::BTreeMap = std::collections::BTreeMap::new(); + + let mut flat_lane_idx: usize = 0; + let bus = cpu_bus; + let cpu_z_k = crate::memory_sidecar::cpu_bus::decode_cpu_z_to_k(params, &step.mcs.1.Z); + if bus.shout_cols.len() != step.lut_instances.len() || bus.twist_cols.len() != step.mem_instances.len() { + return Err(PiCcsError::InvalidInput( + "shared_cpu_bus layout mismatch for step (instance counts)".into(), + )); + } + let mut addr_range_counts = std::collections::HashMap::<(usize, usize), usize>::new(); + for inst_cols in bus.shout_cols.iter() { + for lane_cols in inst_cols.lanes.iter() { + let key = (lane_cols.addr_bits.start, lane_cols.addr_bits.end); + *addr_range_counts.entry(key).or_insert(0) += 1; + } + } + // Shared-bus trace mode can have many lookup families reusing the same bus columns + // (e.g. decode/width selector+addr groups and opcode addr groups). Cache sparse + // decodes by (col_id, steps) to avoid rebuilding identical SparseIdxVec values. 
+ let mut full_col_sparse_cache: std::collections::HashMap<(usize, usize), SparseIdxVec> = + std::collections::HashMap::new(); + let mut has_lookup_cache: std::collections::HashMap<(usize, usize), (SparseIdxVec, Vec, bool)> = + std::collections::HashMap::new(); + + let mut decode_full_col = |col_id: usize, steps: usize| -> Result, PiCcsError> { + if let Some(cached) = full_col_sparse_cache.get(&(col_id, steps)) { + return Ok(cached.clone()); + } + let decoded = + crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col(&cpu_z_k, bus, col_id, steps, pow2_cycle)?; + full_col_sparse_cache.insert((col_id, steps), decoded.clone()); + Ok(decoded) + }; + + for (idx, (lut_inst, _lut_wit)) in step.lut_instances.iter().enumerate() { + neo_memory::addr::validate_shout_bit_addressing(lut_inst)?; + if lut_inst.steps > pow2_cycle { + return Err(PiCcsError::InvalidInput(format!( + "Shout(Route A): steps={} exceeds 2^ell_cycle={pow2_cycle}", + lut_inst.steps + ))); + } + + let z = &cpu_z_k; + let inst_ell_addr = lut_inst.d * lut_inst.ell; + if matches!( + lut_inst.table_spec, + Some(LutTableSpec::RiscvOpcodePacked { .. } | LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}) + ) { + return Err(PiCcsError::InvalidInput( + "packed RISC-V Shout table specs are not supported on the shared CPU bus".into(), + )); + } + let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) + .map_err(|_| PiCcsError::InvalidInput("Shout(Route A): ell_addr overflows u32".into()))?; + groups + .entry(inst_ell_addr_u32) + .or_insert_with(|| AddrPreGroupBuilder { + active_lanes: Vec::new(), + active_claimed_sums: Vec::new(), + addr_oracles: Vec::new(), + }); + let inst_cols = bus.shout_cols.get(idx).ok_or_else(|| { + PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch: missing shout_cols for lut_idx={idx}" + )) + })?; + let expected_lanes = lut_inst.lanes.max(1); + if inst_cols.lanes.len() != expected_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at lut_idx={idx}: shout lanes={} but instance expects {}", + inst_cols.lanes.len(), + expected_lanes + ))); + } + + let mut lanes: Vec = Vec::with_capacity(expected_lanes); + + for (lane_idx, shout_cols) in inst_cols.lanes.iter().enumerate() { + if shout_cols.addr_bits.end - shout_cols.addr_bits.start != inst_ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at lut_idx={idx}, lane_idx={lane_idx}: expected ell_addr={inst_ell_addr}" + ))); + } + let addr_key = (shout_cols.addr_bits.start, shout_cols.addr_bits.end); + let shared_addr_group = addr_range_counts.get(&addr_key).copied().unwrap_or(0) > 1; + + let (has_lookup, active_js, has_any_lookup) = if let Some((cached_has, cached_js, cached_any)) = + has_lookup_cache.get(&(shout_cols.has_lookup, lut_inst.steps)) + { + (cached_has.clone(), cached_js.clone(), *cached_any) + } else { + let has_lookup = decode_full_col(shout_cols.has_lookup, lut_inst.steps)?; + let has_any_lookup = has_lookup + .entries() + .iter() + .any(|&(_t, gate)| gate != K::ZERO); + let active_js: Vec = if has_any_lookup { + let m_in = bus.m_in; + let mut out: Vec = 
Vec::with_capacity(has_lookup.entries().len()); + for &(t, gate) in has_lookup.entries() { + if gate == K::ZERO { + continue; + } + let j = t.checked_sub(m_in).ok_or_else(|| { + PiCcsError::InvalidInput(format!( + "Shout(Route A): has_lookup time index underflow: t={t} < m_in={m_in}" + )) + })?; + if j >= lut_inst.steps { + return Err(PiCcsError::ProtocolError(format!( + "Shout(Route A): has_lookup time index out of range: j={j} >= steps={}", + lut_inst.steps + ))); + } + out.push(j); + } + out + } else { + Vec::new() + }; + has_lookup_cache.insert( + (shout_cols.has_lookup, lut_inst.steps), + (has_lookup.clone(), active_js.clone(), has_any_lookup), + ); + (has_lookup, active_js, has_any_lookup) + }; + + let addr_bits: Vec> = if shared_addr_group { + let mut out = Vec::with_capacity(inst_ell_addr); + for col_id in shout_cols.addr_bits.clone() { + out.push(decode_full_col(col_id, lut_inst.steps)?); + } + out + } else if has_any_lookup { + let mut out = Vec::with_capacity(inst_ell_addr); + for col_id in shout_cols.addr_bits.clone() { + out.push(crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col_at_js( + z, bus, col_id, &active_js, pow2_cycle, + )?); + } + out + } else { + vec![SparseIdxVec::new(pow2_cycle); inst_ell_addr] + }; + + let val = if has_any_lookup { + crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col_at_js( + z, + bus, + shout_cols.primary_val(), + &active_js, + pow2_cycle, + )? 
+ } else { + SparseIdxVec::new(pow2_cycle) + }; + + if has_any_lookup { + let (addr_oracle, lane_sum): (Box, K) = match &lut_inst.table_spec { + None => { + let table_k: Vec = lut_inst.table.iter().map(|&v| v.into()).collect(); + let (o, sum) = + AddressLookupOracle::new(&addr_bits, &has_lookup, &table_k, r_cycle, inst_ell_addr); + (Box::new(o), sum) + } + Some(LutTableSpec::RiscvOpcode { opcode, xlen }) => { + let (o, sum) = RiscvAddressLookupOracleSparse::new_sparse_time( + *opcode, + *xlen, + &addr_bits, + &has_lookup, + r_cycle, + )?; + (Box::new(o), sum) + } + Some(LutTableSpec::RiscvOpcodePacked { .. }) => { + return Err(PiCcsError::InvalidInput( + "packed RISC-V Shout table specs are not supported on the shared CPU bus".into(), + )); + } + Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. }) => { + return Err(PiCcsError::InvalidInput( + "packed RISC-V Shout table specs are not supported on the shared CPU bus".into(), + )); + } + Some(LutTableSpec::IdentityU32) => { + let (o, sum) = IdentityAddressLookupOracleSparse::new_sparse_time( + inst_ell_addr, + &addr_bits, + &has_lookup, + r_cycle, + )?; + (Box::new(o), sum) + } + }; + + claimed_sums[flat_lane_idx] = lane_sum; + let lane_idx_u32 = u32::try_from(flat_lane_idx) + .map_err(|_| PiCcsError::InvalidInput("Shout(Route A): lane index overflow".into()))?; + let group = groups + .get_mut(&inst_ell_addr_u32) + .ok_or_else(|| PiCcsError::ProtocolError("Shout(Route A): missing ell_addr group".into()))?; + group.active_lanes.push(lane_idx_u32); + group.active_claimed_sums.push(lane_sum); + group.addr_oracles.push(addr_oracle); + } + + lanes.push(ShoutLaneSparseCols { + addr_bits, + has_lookup, + val, + }); + flat_lane_idx += 1; + } + + let decoded = ShoutDecodedColsSparse { lanes }; + + decoded_cols.push(decoded); + } + if flat_lane_idx != total_lanes { + return Err(PiCcsError::ProtocolError(format!( + "Shout(Route A): flat lane indexing drift (got {flat_lane_idx}, expected {total_lanes})" + ))); + } + + let 
labels_all: Vec<&'static [u8]> = vec![b"shout/addr_pre".as_slice(); total_lanes]; + tr.append_message(b"shout/addr_pre_time/step_idx", &(step_idx as u64).to_le_bytes()); + bind_batched_claim_sums(tr, b"shout/addr_pre_time/claimed_sums", &claimed_sums, &labels_all); + + let mut group_proofs: Vec> = Vec::with_capacity(groups.len()); + for (group_idx, (ell_addr, mut group)) in groups.into_iter().enumerate() { + tr.append_message(b"shout/addr_pre_time/group_idx", &(group_idx as u64).to_le_bytes()); + tr.append_message(b"shout/addr_pre_time/group_ell_addr", &(ell_addr as u64).to_le_bytes()); + + let (r_addr, round_polys) = if group.active_lanes.is_empty() { + // No active lanes in this `ell_addr` group; sample an arbitrary `r_addr` without running sumcheck. + tr.append_message(b"shout/addr_pre_time/no_sumcheck", &(step_idx as u64).to_le_bytes()); + tr.append_message( + b"shout/addr_pre_time/no_sumcheck/ell_addr", + &(ell_addr as u64).to_le_bytes(), + ); + ( + ts::sample_ext_point( + tr, + b"shout/addr_pre_time/no_sumcheck/r_addr", + b"shout/addr_pre_time/no_sumcheck/r_addr/0", + b"shout/addr_pre_time/no_sumcheck/r_addr/1", + ell_addr as usize, + ), + Vec::new(), + ) + } else { + let labels_active: Vec<&'static [u8]> = vec![b"shout/addr_pre".as_slice(); group.addr_oracles.len()]; + let mut claims: Vec> = group + .addr_oracles + .iter_mut() + .zip(group.active_claimed_sums.iter()) + .zip(labels_active.iter()) + .map(|((oracle, sum), label)| BatchedClaim { + oracle: oracle.as_mut(), + claimed_sum: *sum, + label: *label, + }) + .collect(); + + let (r_addr, per_claim_results) = + run_batched_sumcheck_prover_ds(tr, b"shout/addr_pre_time", step_idx, claims.as_mut_slice())?; + let round_polys = per_claim_results + .iter() + .map(|r| r.round_polys.clone()) + .collect::>(); + (r_addr, round_polys) + }; + + group_proofs.push(ShoutAddrPreGroupProof { + ell_addr, + active_lanes: group.active_lanes, + round_polys, + r_addr, + }); + } + + Ok(ShoutAddrPreBatchProverData { + addr_pre: 
ShoutAddrPreProof { + claimed_sums, + groups: group_proofs, + }, + decoded: decoded_cols, + }) +} + +pub fn verify_shout_addr_pre_time( + tr: &mut Poseidon2Transcript, + step: &StepInstanceBundle, + mem_proof: &MemSidecarProof, + step_idx: usize, +) -> Result, PiCcsError> { + let proof = &mem_proof.shout_addr_pre; + + if step.lut_insts.is_empty() { + if !proof.claimed_sums.is_empty() || !proof.groups.is_empty() { + return Err(PiCcsError::InvalidInput( + "shout_addr_pre must be empty when there are no Shout instances".into(), + )); + } + return Ok(Vec::new()); + } + + let total_lanes: usize = step.lut_insts.iter().map(|inst| inst.lanes.max(1)).sum(); + if proof.claimed_sums.len() != total_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shout_addr_pre claimed_sums.len()={}, expected total_lanes={}", + proof.claimed_sums.len(), + total_lanes + ))); + } + + // Flatten lane->ell_addr mapping in canonical order so we can validate group membership and + // attach the correct `r_addr` per lane. + let mut lane_ell_addr: Vec = Vec::with_capacity(total_lanes); + let mut required_ell_addrs: std::collections::BTreeSet = std::collections::BTreeSet::new(); + for lut_inst in step.lut_insts.iter() { + neo_memory::addr::validate_shout_bit_addressing(lut_inst)?; + let inst_ell_addr = lut_inst.d * lut_inst.ell; + let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) + .map_err(|_| PiCcsError::InvalidInput("Shout: ell_addr overflows u32".into()))?; + required_ell_addrs.insert(inst_ell_addr_u32); + for _lane_idx in 0..lut_inst.lanes.max(1) { + lane_ell_addr.push(inst_ell_addr_u32); + } + } + if lane_ell_addr.len() != total_lanes { + return Err(PiCcsError::ProtocolError( + "shout addr-pre lane indexing drift (lane_ell_addr)".into(), + )); + } + + // Groups must match the step's required `ell_addr` set and be sorted/unique. 
+ if proof.groups.len() != required_ell_addrs.len() { + return Err(PiCcsError::InvalidInput(format!( + "shout_addr_pre groups.len()={}, expected {} (distinct ell_addr values in step)", + proof.groups.len(), + required_ell_addrs.len() + ))); + } + let required_list: Vec = required_ell_addrs.into_iter().collect(); + for (idx, group) in proof.groups.iter().enumerate() { + let expected_ell_addr = required_list[idx]; + if group.ell_addr != expected_ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shout_addr_pre groups not sorted or mismatched: groups[{idx}].ell_addr={} but expected {expected_ell_addr}", + group.ell_addr + ))); + } + if group.r_addr.len() != group.ell_addr as usize { + return Err(PiCcsError::InvalidInput(format!( + "shout_addr_pre group ell_addr={} has r_addr.len()={}, expected {}", + group.ell_addr, + group.r_addr.len(), + group.ell_addr + ))); + } + if group.round_polys.len() != group.active_lanes.len() { + return Err(PiCcsError::InvalidInput(format!( + "shout_addr_pre group ell_addr={} round_polys.len()={}, expected active_lanes.len()={}", + group.ell_addr, + group.round_polys.len(), + group.active_lanes.len() + ))); + } + + for (pos, &lane_idx) in group.active_lanes.iter().enumerate() { + let lane_idx_usize = lane_idx as usize; + if lane_idx_usize >= total_lanes { + return Err(PiCcsError::InvalidInput( + "shout_addr_pre active_lanes has index out of range".into(), + )); + } + if lane_ell_addr[lane_idx_usize] != group.ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shout_addr_pre active_lanes contains lane_idx={} with ell_addr={}, but group ell_addr={}", + lane_idx, lane_ell_addr[lane_idx_usize], group.ell_addr + ))); + } + if pos > 0 && group.active_lanes[pos - 1] >= lane_idx { + return Err(PiCcsError::InvalidInput( + "shout_addr_pre active_lanes must be strictly increasing".into(), + )); + } + } + for (pos, rounds) in group.round_polys.iter().enumerate() { + if rounds.len() != group.ell_addr as usize { + return 
Err(PiCcsError::InvalidInput(format!( + "shout_addr_pre group ell_addr={} round_polys[{pos}].len()={}, expected {}", + group.ell_addr, + rounds.len(), + group.ell_addr + ))); + } + } + } + + // Bind all claimed sums (all lanes) once. + let labels_all: Vec<&'static [u8]> = vec![b"shout/addr_pre".as_slice(); total_lanes]; + tr.append_message(b"shout/addr_pre_time/step_idx", &(step_idx as u64).to_le_bytes()); + bind_batched_claim_sums( + tr, + b"shout/addr_pre_time/claimed_sums", + &proof.claimed_sums, + &labels_all, + ); + + // Verify each `ell_addr` group independently, collecting per-lane addr-pre finals and + // recording the shared `r_addr` for that group. + let mut lane_is_active = vec![false; total_lanes]; + let mut lane_addr_final = vec![K::ZERO; total_lanes]; + let mut r_addr_by_ell: std::collections::BTreeMap> = std::collections::BTreeMap::new(); + let mut seen_active: std::collections::HashSet = std::collections::HashSet::new(); + + for (group_idx, group) in proof.groups.iter().enumerate() { + tr.append_message(b"shout/addr_pre_time/group_idx", &(group_idx as u64).to_le_bytes()); + tr.append_message( + b"shout/addr_pre_time/group_ell_addr", + &(group.ell_addr as u64).to_le_bytes(), + ); + + if group.active_lanes.is_empty() { + // No active lanes in this group: match prover's deterministic fallback sampling. 
+ tr.append_message(b"shout/addr_pre_time/no_sumcheck", &(step_idx as u64).to_le_bytes()); + tr.append_message( + b"shout/addr_pre_time/no_sumcheck/ell_addr", + &(group.ell_addr as u64).to_le_bytes(), + ); + let r_addr = ts::sample_ext_point( + tr, + b"shout/addr_pre_time/no_sumcheck/r_addr", + b"shout/addr_pre_time/no_sumcheck/r_addr/0", + b"shout/addr_pre_time/no_sumcheck/r_addr/1", + group.ell_addr as usize, + ); + if r_addr != group.r_addr { + return Err(PiCcsError::ProtocolError( + "shout_addr_pre r_addr mismatch: transcript-derived vs proof".into(), + )); + } + r_addr_by_ell.insert(group.ell_addr, r_addr); + continue; + } + + let active_count = group.active_lanes.len(); + let mut active_claimed_sums: Vec = Vec::with_capacity(active_count); + for &lane_idx in group.active_lanes.iter() { + if !seen_active.insert(lane_idx) { + return Err(PiCcsError::InvalidInput( + "shout_addr_pre active_lanes contains duplicates across groups".into(), + )); + } + active_claimed_sums.push( + *proof + .claimed_sums + .get(lane_idx as usize) + .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre active lane idx drift".into()))?, + ); + } + let labels_active: Vec<&'static [u8]> = vec![b"shout/addr_pre".as_slice(); active_count]; + let degree_bounds = vec![2usize; active_count]; + let (r_addr, finals, ok) = verify_batched_sumcheck_rounds_ds( + tr, + b"shout/addr_pre_time", + step_idx, + &group.round_polys, + &active_claimed_sums, + &labels_active, + °ree_bounds, + ); + if !ok { + return Err(PiCcsError::SumcheckError( + "shout addr-pre batched sumcheck invalid".into(), + )); + } + if r_addr != group.r_addr { + return Err(PiCcsError::ProtocolError( + "shout_addr_pre r_addr mismatch: transcript-derived vs proof".into(), + )); + } + if finals.len() != active_count { + return Err(PiCcsError::ProtocolError(format!( + "shout addr-pre finals.len()={}, expected active_count={active_count}", + finals.len() + ))); + } + + for (pos, &lane_idx) in group.active_lanes.iter().enumerate() { + 
let lane_idx_usize = lane_idx as usize; + lane_is_active[lane_idx_usize] = true; + lane_addr_final[lane_idx_usize] = finals[pos]; + } + r_addr_by_ell.insert(group.ell_addr, r_addr); + } + + // Build per-lane verify data in canonical order. + let mut out = Vec::with_capacity(total_lanes); + for (lut_inst, inst_ell_addr) in step.lut_insts.iter().map(|inst| (inst, inst.d * inst.ell)) { + let expected_lanes = lut_inst.lanes.max(1); + let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) + .map_err(|_| PiCcsError::InvalidInput("Shout: ell_addr overflows u32".into()))?; + let r_addr = r_addr_by_ell + .get(&inst_ell_addr_u32) + .ok_or_else(|| PiCcsError::ProtocolError("missing shout addr-pre group r_addr".into()))?; + + for _lane_idx in 0..expected_lanes { + let flat_lane_idx = out.len(); + let addr_claim_sum = *proof + .claimed_sums + .get(flat_lane_idx) + .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre lane index drift".into()))?; + let is_active = *lane_is_active + .get(flat_lane_idx) + .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre lane idx drift".into()))?; + let addr_final = *lane_addr_final + .get(flat_lane_idx) + .ok_or_else(|| PiCcsError::ProtocolError("shout addr-pre lane idx drift".into()))?; + + let table_eval_at_r_addr = if is_active { + match &lut_inst.table_spec { + None => { + let pow2 = 1usize + .checked_shl(r_addr.len() as u32) + .ok_or_else(|| PiCcsError::InvalidInput("Shout: 2^ell_addr overflow".into()))?; + let mut acc = K::ZERO; + for (i, &v) in lut_inst.table.iter().enumerate().take(pow2) { + let w = neo_memory::mle::chi_at_index(r_addr, i); + acc += K::from(v) * w; + } + acc + } + Some(spec) => spec.eval_table_mle(r_addr)?, + } + } else { + K::ZERO + }; + + out.push(ShoutAddrPreVerifyData { + is_active, + addr_claim_sum, + addr_final: if is_active { addr_final } else { K::ZERO }, + r_addr: r_addr.clone(), + table_eval_at_r_addr, + }); + } + } + if out.len() != total_lanes { + return Err(PiCcsError::ProtocolError("shout addr-pre 
lane count mismatch".into())); + } + + Ok(out) +} + +pub fn verify_twist_addr_pre_time( + tr: &mut Poseidon2Transcript, + step: &StepInstanceBundle, + mem_proof: &MemSidecarProof, +) -> Result, PiCcsError> { + let mut out = Vec::with_capacity(step.mem_insts.len()); + let proof_offset = step.lut_insts.len(); + + for (idx, mem_inst) in step.mem_insts.iter().enumerate() { + let proof = match mem_proof.proofs.get(proof_offset + idx) { + Some(MemOrLutProof::Twist(p)) => p, + _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), + }; + + if proof.addr_pre.claimed_sums.len() != 2 { + return Err(PiCcsError::InvalidInput(format!( + "twist addr_pre claimed_sums.len()={}, expected 2", + proof.addr_pre.claimed_sums.len() + ))); + } + if proof.addr_pre.round_polys.len() != 2 { + return Err(PiCcsError::InvalidInput(format!( + "twist addr_pre round_polys.len()={}, expected 2", + proof.addr_pre.round_polys.len() + ))); + } + if proof.addr_pre.claimed_sums[0] != K::ZERO || proof.addr_pre.claimed_sums[1] != K::ZERO { + return Err(PiCcsError::ProtocolError( + "twist addr_pre claimed_sums mismatch (expected both 0)".into(), + )); + } + + let labels: [&[u8]; 2] = [b"twist/read_addr_pre".as_slice(), b"twist/write_addr_pre".as_slice()]; + let degree_bounds = vec![2usize, 2usize]; + tr.append_message(b"twist/addr_pre_time/claim_idx", &(idx as u64).to_le_bytes()); + bind_batched_claim_sums( + tr, + b"twist/addr_pre_time/claimed_sums", + &proof.addr_pre.claimed_sums, + &labels, + ); + + let (r_addr, finals, ok) = verify_batched_sumcheck_rounds_ds( + tr, + b"twist/addr_pre_time", + idx, + &proof.addr_pre.round_polys, + &proof.addr_pre.claimed_sums, + &labels, + °ree_bounds, + ); + if !ok { + return Err(PiCcsError::SumcheckError( + "twist addr-pre batched sumcheck invalid".into(), + )); + } + if r_addr != proof.addr_pre.r_addr { + return Err(PiCcsError::ProtocolError( + "twist addr_pre r_addr mismatch: transcript-derived vs proof".into(), + )); + } + + let ell_addr = 
mem_inst.d * mem_inst.ell; + if r_addr.len() != ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "twist addr_pre r_addr.len()={}, expected ell_addr={}", + r_addr.len(), + ell_addr + ))); + } + if finals.len() != 2 { + return Err(PiCcsError::ProtocolError(format!( + "twist addr-pre finals.len()={}, expected 2", + finals.len() + ))); + } + + out.push(TwistAddrPreVerifyData { + r_addr, + read_check_claim_sum: finals[0], + write_check_claim_sum: finals[1], + }); + } + + Ok(out) +} diff --git a/crates/neo-fold/src/memory_sidecar/memory/event_table_context.rs b/crates/neo-fold/src/memory_sidecar/memory/event_table_context.rs new file mode 100644 index 00000000..47b7d126 --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/event_table_context.rs @@ -0,0 +1,148 @@ +use super::*; + +pub(crate) fn build_event_table_shout_context( + params: &NeoParams, + step: &StepWitnessBundle, + ell_n: usize, + r_cycle: &[K], +) -> Result<(K, K, K, Option), PiCcsError> { + let any_event_table_shout = step + .lut_instances + .iter() + .any(|(inst, _wit)| matches!(inst.table_spec, Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. }))); + if any_event_table_shout { + for (idx, (inst, _wit)) in step.lut_instances.iter().enumerate() { + if !matches!(inst.table_spec, Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. 
})) { + return Err(PiCcsError::InvalidInput(format!( + "event-table Shout mode requires all Shout instances to use RiscvOpcodeEventTablePacked (lut_idx={idx})" + ))); + } + } + } + + let (event_alpha, event_beta, event_gamma) = if any_event_table_shout { + if r_cycle.len() < 3 { + return Err(PiCcsError::InvalidInput("event-table Shout requires ell_n >= 3".into())); + } + (r_cycle[0], r_cycle[1], r_cycle[2]) + } else { + (K::ZERO, K::ZERO, K::ZERO) + }; + + let shout_event_trace_hash: Option = if any_event_table_shout { + let m_in = step.mcs.0.m_in; + if m_in != 5 { + return Err(PiCcsError::InvalidInput(format!( + "event-table Shout trace linkage expects m_in=5 (got {m_in})" + ))); + } + let trace = Rv32TraceLayout::new(); + let m = step.mcs.1.Z.cols(); + let t_len = step + .mem_instances + .first() + .map(|(inst, _wit)| inst.steps) + .or_else(|| { + let w = m.checked_sub(m_in)?; + if trace.cols == 0 || w % trace.cols != 0 { + return None; + } + Some(w / trace.cols) + }) + .ok_or_else(|| PiCcsError::InvalidInput("event-table Shout trace linkage missing t_len".into()))?; + if t_len == 0 { + return Err(PiCcsError::InvalidInput( + "event-table Shout trace linkage requires t_len >= 1".into(), + )); + } + let pow2_cycle = 1usize + .checked_shl(ell_n as u32) + .ok_or_else(|| PiCcsError::InvalidInput("event-table Shout: 2^ell_n overflow".into()))?; + if m_in + .checked_add(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("event-table Shout: m_in + t_len overflow".into()))? 
+ > pow2_cycle + { + return Err(PiCcsError::InvalidInput(format!( + "event-table Shout: trace time rows out of range: m_in({m_in}) + t_len({t_len}) > 2^ell_n({pow2_cycle})" + ))); + } + + let d = neo_math::D; + let z = &step.mcs.1.Z; + if z.rows() != d { + return Err(PiCcsError::InvalidInput(format!( + "event-table Shout: CPU witness Z.rows()={} != D={d}", + z.rows() + ))); + } + if z.cols() != m { + return Err(PiCcsError::ProtocolError( + "event-table Shout: CPU witness width drift".into(), + )); + } + + let b_k = K::from(F::from_u64(params.b as u64)); + let mut pow_b = Vec::with_capacity(d); + let mut cur = K::ONE; + for _ in 0..d { + pow_b.push(cur); + cur *= b_k; + } + let decode_idx = |idx: usize| -> Result { + if idx >= m { + return Err(PiCcsError::InvalidInput(format!( + "event-table Shout: z idx out of range (idx={idx}, m={m})" + ))); + } + let mut acc = K::ZERO; + for rho in 0..d { + acc += pow_b[rho] * K::from(z[(rho, idx)]); + } + Ok(acc) + }; + + let trace_base = m_in; + let shout_col = |col_id: usize, j: usize| -> Result { + let col_offset = col_id + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_id * t_len overflow".into()))?; + let idx = trace_base + .checked_add(col_offset) + .and_then(|x| x.checked_add(j)) + .ok_or_else(|| PiCcsError::InvalidInput("trace z idx overflow".into()))?; + decode_idx(idx) + }; + + let mut gate_entries: Vec<(usize, K)> = Vec::new(); + let mut hash_entries: Vec<(usize, K)> = Vec::new(); + for j in 0..t_len { + let t = m_in + j; + let gate = shout_col(trace.shout_has_lookup, j)?; + if gate == K::ZERO { + continue; + } + gate_entries.push((t, gate)); + + let val = shout_col(trace.shout_val, j)?; + let lhs = shout_col(trace.shout_lhs, j)?; + let rhs = shout_col(trace.shout_rhs, j)?; + let hash = K::ONE + event_alpha * val + event_beta * lhs + event_gamma * rhs; + if hash != K::ZERO { + hash_entries.push((t, hash)); + } + } + + let gate = SparseIdxVec::from_entries(pow2_cycle, gate_entries); + let hash 
= SparseIdxVec::from_entries(pow2_cycle, hash_entries); + let (oracle, claim) = ShoutValueOracleSparse::new(r_cycle, gate, hash); + Some(RouteAShoutEventTraceHashOracle { + oracle: Box::new(oracle), + claim, + }) + } else { + None + }; + + Ok((event_alpha, event_beta, event_gamma, shout_event_trace_hash)) +} diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs new file mode 100644 index 00000000..e57012f1 --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs @@ -0,0 +1,945 @@ +use super::*; + +pub(crate) fn width_lookup_bus_val_cols_witness( + step: &StepWitnessBundle, + t_len: usize, +) -> Result, PiCcsError> { + let width = Rv32WidthSidecarLayout::new(); + let width_cols = rv32_width_lookup_backed_cols(&width); + let mut width_bus_col_by_col: BTreeMap = BTreeMap::new(); + let m_in = step.mcs.0.m_in; + let bus = build_bus_layout_for_step_witness(step, t_len)?; + if bus.shout_cols.len() != step.lut_instances.len() { + return Err(PiCcsError::ProtocolError( + "W3(shared): bus shout lane count drift while resolving width lookup columns".into(), + )); + } + let bus_base_delta = bus + .bus_base + .checked_sub(m_in) + .ok_or_else(|| PiCcsError::ProtocolError("W3(shared): bus_base underflow".into()))?; + if bus_base_delta % t_len != 0 { + return Err(PiCcsError::ProtocolError(format!( + "W3(shared): bus_base alignment mismatch (bus_base_delta={bus_base_delta}, t_len={t_len})" + ))); + } + let bus_col_offset = bus_base_delta / t_len; + for (lut_idx, (inst, _)) in step.lut_instances.iter().enumerate() { + if !rv32_is_width_lookup_table_id(inst.table_id) { + continue; + } + let width_col_id = width_cols + .iter() + .copied() + .find(|&col_id| rv32_width_lookup_table_id_for_col(col_id) == inst.table_id) + .ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "W3(shared): width lookup table_id={} does not map to a known width column", + 
inst.table_id + )) + })?; + let inst_cols = bus + .shout_cols + .get(lut_idx) + .ok_or_else(|| PiCcsError::ProtocolError("W3(shared): missing shout cols for width lookup table".into()))?; + let lane0 = inst_cols.lanes.get(0).ok_or_else(|| { + PiCcsError::ProtocolError("W3(shared): expected one shout lane for width lookup table".into()) + })?; + width_bus_col_by_col.insert(width_col_id, bus_col_offset + lane0.primary_val()); + } + let mut out = Vec::with_capacity(width_cols.len()); + for &col_id in width_cols.iter() { + let bus_col = width_bus_col_by_col.get(&col_id).copied().ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "W3(shared): missing width lookup bus val column for width col_id={col_id}" + )) + })?; + out.push(bus_col); + } + Ok(out) +} + +pub(crate) fn build_route_a_width_time_claims( + params: &NeoParams, + step: &StepWitnessBundle, + r_cycle: &[K], +) -> Result { + if !width_stage_required_for_step_witness(step) { + return Ok((None, None, None, None, None)); + } + let trace = Rv32TraceLayout::new(); + let width = Rv32WidthSidecarLayout::new(); + let decode = Rv32DecodeSidecarLayout::new(); + let m_in = step.mcs.0.m_in; + let ell_n = r_cycle.len(); + let t_len = infer_rv32_trace_t_len_for_wb_wp(step, &trace)?; + if t_len == 0 { + return Err(PiCcsError::InvalidInput("W3: t_len must be >= 1".into())); + } + + let main_col_ids = [ + trace.active, + trace.instr_word, + trace.rd_val, + trace.ram_rv, + trace.ram_wv, + trace.rs2_val, + ]; + let main_decoded = decode_trace_col_values_batch(params, step, t_len, &main_col_ids)?; + let width_col_ids = rv32_width_lookup_backed_cols(&width); + let width_decoded: BTreeMap> = { + let width_bus_abs_cols = width_lookup_bus_val_cols_witness(step, t_len)?; + let bus = build_bus_layout_for_step_witness(step, t_len)?; + let bus_base_delta = bus + .bus_base + .checked_sub(m_in) + .ok_or_else(|| PiCcsError::ProtocolError("W3(shared): bus_base underflow".into()))?; + if bus_base_delta % t_len != 0 { + return 
Err(PiCcsError::ProtocolError(format!( + "W3(shared): bus_base alignment mismatch (bus_base_delta={bus_base_delta}, t_len={t_len})" + ))); + } + let bus_col_offset = bus_base_delta / t_len; + let mut width_bus_val_cols = Vec::with_capacity(width_bus_abs_cols.len()); + for abs_col in width_bus_abs_cols.iter().copied() { + let local_col = abs_col.checked_sub(bus_col_offset).ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "W3(shared): width lookup bus column underflow (abs_col={abs_col}, bus_col_offset={bus_col_offset})" + )) + })?; + if local_col >= bus.bus_cols { + return Err(PiCcsError::ProtocolError(format!( + "W3(shared): width lookup bus column out of range (local_col={local_col}, bus_cols={})", + bus.bus_cols + ))); + } + width_bus_val_cols.push(local_col); + } + let lookup_vals = decode_lookup_backed_col_values_batch( + params, + bus.bus_base, + t_len, + &step.mcs.1.Z, + bus.bus_cols, + &width_bus_val_cols, + )?; + let mut by_col = BTreeMap::>::new(); + for (idx, &col_id) in width_col_ids.iter().enumerate() { + let bus_col_id = width_bus_val_cols[idx]; + let vals = lookup_vals.get(&bus_col_id).ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "W3(shared): missing decoded lookup values for bus_col={bus_col_id}" + )) + })?; + by_col.insert(col_id, vals.clone()); + } + by_col + }; + let decode_col_ids: Vec = core::iter::once(decode.op_load) + .chain(core::iter::once(decode.op_store)) + .chain(core::iter::once(decode.rd_has_write)) + .chain(core::iter::once(decode.ram_has_read)) + .chain(core::iter::once(decode.ram_has_write)) + .chain(decode.funct3_is.iter().copied()) + .collect(); + let decode_decoded = { + let instr_vals = main_decoded + .get(&trace.instr_word) + .ok_or_else(|| PiCcsError::ProtocolError("W3(shared): missing instr_word decode column".into()))?; + let active_vals = main_decoded + .get(&trace.active) + .ok_or_else(|| PiCcsError::ProtocolError("W3(shared): missing active decode column".into()))?; + if instr_vals.len() != t_len || 
active_vals.len() != t_len { + return Err(PiCcsError::ProtocolError(format!( + "W3(shared): decoded CPU column lengths drift (instr={}, active={}, t_len={t_len})", + instr_vals.len(), + active_vals.len() + ))); + } + let mut decoded = BTreeMap::>::new(); + for &col_id in decode_col_ids.iter() { + decoded.insert(col_id, Vec::with_capacity(t_len)); + } + for j in 0..t_len { + let instr_word = decode_k_to_u32(instr_vals[j], "W3(shared)/instr_word")?; + let active = active_vals[j] != K::ZERO; + let mut row = rv32_decode_lookup_backed_row_from_instr_word(&decode, instr_word, active); + if !active { + row.fill(F::ZERO); + } + for &col_id in decode_col_ids.iter() { + decoded + .get_mut(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError("W3(shared): decode map build failed".into()))? + .push(K::from(row[col_id])); + } + } + decoded + }; + + let mut main_sparse = BTreeMap::>::new(); + for &col_id in main_col_ids.iter() { + let vals = main_decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing main decoded column {col_id}")))?; + main_sparse.insert(col_id, sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + let mut width_sparse = BTreeMap::>::new(); + for &col_id in width_col_ids.iter() { + let vals = width_decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing width decoded column {col_id}")))?; + width_sparse.insert(col_id, sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + let mut decode_sparse = BTreeMap::>::new(); + for &col_id in decode_col_ids.iter() { + let vals = decode_decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing decode decoded column {col_id}")))?; + decode_sparse.insert(col_id, sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + + let main_col = |col_id: usize| -> Result, PiCcsError> { + main_sparse + .get(&col_id) + .cloned() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing main sparse column {col_id}"))) + }; + let width_col = 
|col_id: usize| -> Result, PiCcsError> { + width_sparse + .get(&col_id) + .cloned() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing width sparse column {col_id}"))) + }; + let decode_col = |col_id: usize| -> Result, PiCcsError> { + decode_sparse + .get(&col_id) + .cloned() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing decode sparse column {col_id}"))) + }; + + let bitness_cols: Vec = width + .ram_rv_low_bit + .iter() + .chain(width.rs2_low_bit.iter()) + .copied() + .collect(); + let mut bitness_sparse = Vec::with_capacity(bitness_cols.len()); + for &col_id in bitness_cols.iter() { + bitness_sparse.push(width_col(col_id)?); + } + let bitness_weights = w3_bitness_weight_vector(r_cycle, bitness_cols.len()); + let bitness_oracle = FormulaOracleSparseTime::new( + bitness_sparse, + 3, + r_cycle, + Box::new(move |vals: &[K]| { + let mut weighted = K::ZERO; + for (b, w) in vals.iter().zip(bitness_weights.iter()) { + weighted += *w * *b * (*b - K::ONE); + } + weighted + }), + ); + + let mut quiescence_sparse = Vec::with_capacity(1 + width.cols); + quiescence_sparse.push(main_col(trace.active)?); + for &col_id in width_col_ids.iter() { + quiescence_sparse.push(width_col(col_id)?); + } + let quiescence_weights = w3_quiescence_weight_vector(r_cycle, width.cols); + let quiescence_oracle = FormulaOracleSparseTime::new( + quiescence_sparse, + 3, + r_cycle, + Box::new(move |vals: &[K]| { + let active = vals[0]; + let mut weighted = K::ZERO; + for (i, w) in quiescence_weights.iter().enumerate() { + weighted += *w * vals[1 + i]; + } + (K::ONE - active) * weighted + }), + ); + + let mut load_sparse = Vec::with_capacity(31); + load_sparse.push(main_col(trace.rd_val)?); + load_sparse.push(main_col(trace.ram_rv)?); + load_sparse.push(decode_col(decode.rd_has_write)?); + load_sparse.push(decode_col(decode.ram_has_read)?); + load_sparse.push(decode_col(decode.op_load)?); + load_sparse.push(decode_col(decode.funct3_is[0])?); + 
load_sparse.push(decode_col(decode.funct3_is[1])?); + load_sparse.push(decode_col(decode.funct3_is[2])?); + load_sparse.push(decode_col(decode.funct3_is[4])?); + load_sparse.push(decode_col(decode.funct3_is[5])?); + load_sparse.push(width_col(width.ram_rv_q16)?); + for &col_id in width.ram_rv_low_bit.iter() { + load_sparse.push(width_col(col_id)?); + } + let load_weights = w3_load_weight_vector(r_cycle, 16); + let load_oracle = FormulaOracleSparseTime::new( + load_sparse, + 4, + r_cycle, + Box::new(move |vals: &[K]| { + let rd_val = vals[0]; + let ram_rv = vals[1]; + let rd_has_write = vals[2]; + let ram_has_read = vals[3]; + let op_load = vals[4]; + let funct3_is_0 = vals[5]; + let funct3_is_1 = vals[6]; + let funct3_is_2 = vals[7]; + let funct3_is_4 = vals[8]; + let funct3_is_5 = vals[9]; + let ram_rv_q16 = vals[10]; + let load_flags = [ + op_load * funct3_is_0, + op_load * funct3_is_4, + op_load * funct3_is_1, + op_load * funct3_is_5, + op_load * funct3_is_2, + ]; + let mut ram_rv_low_bits = [K::ZERO; 16]; + ram_rv_low_bits.copy_from_slice(&vals[11..27]); + let residuals = w3_load_semantics_residuals( + rd_val, + ram_rv, + rd_has_write, + ram_has_read, + load_flags, + ram_rv_q16, + ram_rv_low_bits, + ); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(load_weights.iter()) { + weighted += *w * *r; + } + weighted + }), + ); + + let mut store_sparse = Vec::with_capacity(45); + store_sparse.push(main_col(trace.ram_wv)?); + store_sparse.push(main_col(trace.ram_rv)?); + store_sparse.push(main_col(trace.rs2_val)?); + store_sparse.push(decode_col(decode.rd_has_write)?); + store_sparse.push(decode_col(decode.ram_has_read)?); + store_sparse.push(decode_col(decode.ram_has_write)?); + store_sparse.push(decode_col(decode.op_store)?); + store_sparse.push(decode_col(decode.funct3_is[0])?); + store_sparse.push(decode_col(decode.funct3_is[1])?); + store_sparse.push(decode_col(decode.funct3_is[2])?); + store_sparse.push(width_col(width.rs2_q16)?); + for &col_id 
in width.ram_rv_low_bit.iter() { + store_sparse.push(width_col(col_id)?); + } + for &col_id in width.rs2_low_bit.iter() { + store_sparse.push(width_col(col_id)?); + } + let store_weights = w3_store_weight_vector(r_cycle, 12); + let store_oracle = FormulaOracleSparseTime::new( + store_sparse, + 4, + r_cycle, + Box::new(move |vals: &[K]| { + let ram_wv = vals[0]; + let ram_rv = vals[1]; + let rs2_val = vals[2]; + let rd_has_write = vals[3]; + let ram_has_read = vals[4]; + let ram_has_write = vals[5]; + let op_store = vals[6]; + let funct3_is_0 = vals[7]; + let funct3_is_1 = vals[8]; + let funct3_is_2 = vals[9]; + let rs2_q16 = vals[10]; + let store_flags = [op_store * funct3_is_0, op_store * funct3_is_1, op_store * funct3_is_2]; + let mut ram_rv_low_bits = [K::ZERO; 16]; + ram_rv_low_bits.copy_from_slice(&vals[11..27]); + let mut rs2_low_bits = [K::ZERO; 16]; + rs2_low_bits.copy_from_slice(&vals[27..43]); + let residuals = w3_store_semantics_residuals( + ram_wv, + ram_rv, + rs2_val, + rd_has_write, + ram_has_read, + ram_has_write, + store_flags, + rs2_q16, + ram_rv_low_bits, + rs2_low_bits, + ); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(store_weights.iter()) { + weighted += *w * *r; + } + weighted + }), + ); + + Ok(( + Some((Box::new(bitness_oracle), K::ZERO)), + Some((Box::new(quiescence_oracle), K::ZERO)), + None, + Some((Box::new(load_oracle), K::ZERO)), + Some((Box::new(store_oracle), K::ZERO)), + )) +} + +type ControlTimeClaims = ( + Option<(Box, K)>, + Option<(Box, K)>, + Option<(Box, K)>, + Option<(Box, K)>, +); + +pub(crate) fn build_route_a_control_time_claims( + params: &NeoParams, + step: &StepWitnessBundle, + r_cycle: &[K], +) -> Result { + if !control_stage_required_for_step_witness(step) { + return Ok((None, None, None, None)); + } + let trace = Rv32TraceLayout::new(); + let decode = Rv32DecodeSidecarLayout::new(); + let m_in = step.mcs.0.m_in; + let ell_n = r_cycle.len(); + let t_len = infer_rv32_trace_t_len_for_wb_wp(step, 
&trace)?; + if t_len == 0 { + return Err(PiCcsError::InvalidInput("control stage: t_len must be >= 1".into())); + } + + let main_col_ids = vec![ + trace.active, + trace.instr_word, + trace.pc_before, + trace.pc_after, + trace.rs1_val, + trace.rd_val, + trace.shout_val, + trace.jalr_drop_bit, + ]; + let decode_col_ids = vec![ + decode.op_lui, + decode.op_auipc, + decode.op_jal, + decode.op_jalr, + decode.op_branch, + decode.op_load, + decode.op_store, + decode.op_alu_imm, + decode.op_alu_reg, + decode.op_misc_mem, + decode.op_system, + decode.op_amo, + decode.op_lui_write, + decode.op_auipc_write, + decode.op_jal_write, + decode.op_jalr_write, + decode.rd_is_zero, + decode.imm_i, + decode.imm_b, + decode.imm_j, + decode.funct3_is[6], + decode.funct3_is[7], + decode.funct3_bit[0], + decode.funct3_bit[1], + decode.funct3_bit[2], + decode.rs1_bit[0], + decode.rs1_bit[1], + decode.rs1_bit[2], + decode.rs1_bit[3], + decode.rs1_bit[4], + decode.rs2_bit[0], + decode.rs2_bit[1], + decode.rs2_bit[2], + decode.rs2_bit[3], + decode.rs2_bit[4], + decode.funct7_bit[0], + decode.funct7_bit[1], + decode.funct7_bit[2], + decode.funct7_bit[3], + decode.funct7_bit[4], + decode.funct7_bit[5], + decode.funct7_bit[6], + ]; + + let main_decoded = decode_trace_col_values_batch(params, step, t_len, &main_col_ids)?; + let decode_decoded = { + let instr_vals = main_decoded + .get(&trace.instr_word) + .ok_or_else(|| PiCcsError::ProtocolError("control(shared): missing instr_word decode column".into()))?; + let active_vals = main_decoded + .get(&trace.active) + .ok_or_else(|| PiCcsError::ProtocolError("control(shared): missing active decode column".into()))?; + if instr_vals.len() != t_len || active_vals.len() != t_len { + return Err(PiCcsError::ProtocolError(format!( + "control(shared): decoded CPU column lengths drift (instr={}, active={}, t_len={t_len})", + instr_vals.len(), + active_vals.len() + ))); + } + let mut decoded = BTreeMap::>::new(); + for &col_id in decode_col_ids.iter() { + 
decoded.insert(col_id, Vec::with_capacity(t_len)); + } + for j in 0..t_len { + let instr_word = decode_k_to_u32(instr_vals[j], "control(shared)/instr_word")?; + let active = active_vals[j] != K::ZERO; + let mut row = rv32_decode_lookup_backed_row_from_instr_word(&decode, instr_word, active); + if !active { + row.fill(F::ZERO); + } + let rd_has_write = if active { + K::ONE - K::from(row[decode.rd_is_zero]) + } else { + K::ZERO + }; + let op_lui = K::from(row[decode.op_lui]); + let op_auipc = K::from(row[decode.op_auipc]); + let op_jal = K::from(row[decode.op_jal]); + let op_jalr = K::from(row[decode.op_jalr]); + for &col_id in decode_col_ids.iter() { + let val = match col_id { + c if c == decode.op_lui_write => op_lui * rd_has_write, + c if c == decode.op_auipc_write => op_auipc * rd_has_write, + c if c == decode.op_jal_write => op_jal * rd_has_write, + c if c == decode.op_jalr_write => op_jalr * rd_has_write, + _ => K::from(row[col_id]), + }; + decoded + .get_mut(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError("control(shared): decode map build failed".into()))? 
+ .push(val); + } + } + decoded + }; + + let mut main_sparse = BTreeMap::>::new(); + for &col_id in main_col_ids.iter() { + let vals = main_decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("control stage missing main decoded column {col_id}")))?; + main_sparse.insert(col_id, sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + let mut decode_sparse = BTreeMap::>::new(); + for &col_id in decode_col_ids.iter() { + let vals = decode_decoded.get(&col_id).ok_or_else(|| { + PiCcsError::ProtocolError(format!("control stage missing decode decoded column {col_id}")) + })?; + decode_sparse.insert(col_id, sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + + let main_col = |col_id: usize| -> Result, PiCcsError> { + main_sparse + .get(&col_id) + .cloned() + .ok_or_else(|| PiCcsError::ProtocolError(format!("control stage missing main sparse col {col_id}"))) + }; + let decode_col = |col_id: usize| -> Result, PiCcsError> { + decode_sparse + .get(&col_id) + .cloned() + .ok_or_else(|| PiCcsError::ProtocolError(format!("control stage missing decode sparse col {col_id}"))) + }; + + let linear_sparse = vec![ + main_col(trace.pc_before)?, + main_col(trace.pc_after)?, + decode_col(decode.op_lui)?, + decode_col(decode.op_auipc)?, + decode_col(decode.op_load)?, + decode_col(decode.op_store)?, + decode_col(decode.op_alu_imm)?, + decode_col(decode.op_alu_reg)?, + decode_col(decode.op_misc_mem)?, + decode_col(decode.op_system)?, + decode_col(decode.op_amo)?, + ]; + let linear_weights = control_next_pc_linear_weight_vector(r_cycle, 1); + let linear_oracle = FormulaOracleSparseTime::new( + linear_sparse, + 3, + r_cycle, + Box::new(move |vals: &[K]| { + let residual = control_next_pc_linear_residual( + vals[0], vals[1], vals[2], vals[3], vals[4], vals[5], vals[6], vals[7], vals[8], vals[9], vals[10], + ); + linear_weights[0] * residual + }), + ); + + let control_sparse = vec![ + main_col(trace.active)?, + main_col(trace.pc_before)?, + 
main_col(trace.pc_after)?, + main_col(trace.rs1_val)?, + main_col(trace.jalr_drop_bit)?, + main_col(trace.shout_val)?, + decode_col(decode.funct3_bit[0])?, + decode_col(decode.op_jal)?, + decode_col(decode.op_jalr)?, + decode_col(decode.op_branch)?, + decode_col(decode.imm_i)?, + decode_col(decode.imm_b)?, + decode_col(decode.imm_j)?, + ]; + let control_weights = control_next_pc_control_weight_vector(r_cycle, 5); + let control_oracle = FormulaOracleSparseTime::new( + control_sparse, + 5, + r_cycle, + Box::new(move |vals: &[K]| { + let residuals = control_next_pc_control_residuals( + vals[0], vals[1], vals[2], vals[3], vals[4], vals[10], vals[11], vals[12], vals[7], vals[8], vals[9], + vals[5], vals[6], + ); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(control_weights.iter()) { + weighted += *w * *r; + } + weighted + }), + ); + + let branch_sparse = vec![ + decode_col(decode.op_branch)?, + main_col(trace.shout_val)?, + decode_col(decode.funct3_bit[0])?, + decode_col(decode.funct3_bit[1])?, + decode_col(decode.funct3_bit[2])?, + decode_col(decode.funct3_is[6])?, + decode_col(decode.funct3_is[7])?, + ]; + let branch_weights = control_branch_semantics_weight_vector(r_cycle, 3); + let branch_oracle = FormulaOracleSparseTime::new( + branch_sparse, + 4, + r_cycle, + Box::new(move |vals: &[K]| { + let residuals = + control_branch_semantics_residuals(vals[0], vals[1], vals[2], vals[3], vals[4], vals[5], vals[6]); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(branch_weights.iter()) { + weighted += *w * *r; + } + weighted + }), + ); + + let mut write_sparse = vec![ + main_col(trace.rd_val)?, + main_col(trace.pc_before)?, + decode_col(decode.op_lui)?, + decode_col(decode.op_auipc)?, + decode_col(decode.op_jal)?, + decode_col(decode.op_jalr)?, + decode_col(decode.rd_is_zero)?, + decode_col(decode.funct3_bit[0])?, + decode_col(decode.funct3_bit[1])?, + decode_col(decode.funct3_bit[2])?, + ]; + for &col_id in decode.rs1_bit.iter() { + 
write_sparse.push(decode_col(col_id)?); + } + for &col_id in decode.rs2_bit.iter() { + write_sparse.push(decode_col(col_id)?); + } + for &col_id in decode.funct7_bit.iter() { + write_sparse.push(decode_col(col_id)?); + } + let write_weights = control_writeback_weight_vector(r_cycle, 4); + let write_oracle = FormulaOracleSparseTime::new( + write_sparse, + 4, + r_cycle, + Box::new(move |vals: &[K]| { + let rd_val = vals[0]; + let pc_before = vals[1]; + let op_lui = vals[2]; + let op_auipc = vals[3]; + let op_jal = vals[4]; + let op_jalr = vals[5]; + let rd_is_zero = vals[6]; + let op_lui_write = op_lui * (K::ONE - rd_is_zero); + let op_auipc_write = op_auipc * (K::ONE - rd_is_zero); + let op_jal_write = op_jal * (K::ONE - rd_is_zero); + let op_jalr_write = op_jalr * (K::ONE - rd_is_zero); + let funct3_bits = [vals[7], vals[8], vals[9]]; + let rs1_bits = [vals[10], vals[11], vals[12], vals[13], vals[14]]; + let rs2_bits = [vals[15], vals[16], vals[17], vals[18], vals[19]]; + let funct7_bits = [vals[20], vals[21], vals[22], vals[23], vals[24], vals[25], vals[26]]; + let imm_u = control_imm_u_from_bits(funct3_bits, rs1_bits, rs2_bits, funct7_bits); + let residuals = control_writeback_residuals( + rd_val, + pc_before, + imm_u, + op_lui_write, + op_auipc_write, + op_jal_write, + op_jalr_write, + ); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(write_weights.iter()) { + weighted += *w * *r; + } + weighted + }), + ); + + Ok(( + Some((Box::new(linear_oracle), K::ZERO)), + Some((Box::new(control_oracle), K::ZERO)), + Some((Box::new(branch_oracle), K::ZERO)), + Some((Box::new(write_oracle), K::ZERO)), + )) +} + +pub(crate) fn emit_route_a_wb_wp_me_claims( + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s: &CcsStructure, + step: &StepWitnessBundle, + r_time: &[K], +) -> Result<(Vec>, Vec>), PiCcsError> { + if !wb_wp_required_for_step_witness(step) { + return Ok((Vec::new(), Vec::new())); + } + + let trace = Rv32TraceLayout::new(); + let t_len = 
infer_rv32_trace_t_len_for_wb_wp(step, &trace)?; + let m_in = step.mcs.0.m_in; + let core_t = s.t(); + let (mcs_inst, mcs_wit) = &step.mcs; + + let wb_cols = rv32_trace_wb_columns(&trace); + let mut wb_claims = ts::emit_me_claims_for_mats( + tr, + b"cpu/me_digest_wb_time", + params, + s, + core::slice::from_ref(&mcs_inst.c), + core::slice::from_ref(&mcs_wit.Z), + r_time, + m_in, + )?; + if wb_claims.len() != 1 { + return Err(PiCcsError::ProtocolError(format!( + "WB expects exactly one CPU ME claim at r_time, got {}", + wb_claims.len() + ))); + } + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + m_in, + &wb_cols, + core_t, + &mcs_wit.Z, + &mut wb_claims[0], + )?; + + let mut wp_cols = rv32_trace_wp_opening_columns(&trace); + if control_stage_required_for_step_witness(step) { + wp_cols.extend(rv32_trace_control_extra_opening_columns(&trace)); + } + if decode_stage_required_for_step_witness(step) { + let decode_layout = Rv32DecodeSidecarLayout::new(); + let (_decode_open_cols, decode_lut_indices) = resolve_shared_decode_lookup_lut_indices(step, &decode_layout)?; + let bus = build_bus_layout_for_step_witness(step, t_len)?; + if bus.shout_cols.len() != step.lut_instances.len() { + return Err(PiCcsError::ProtocolError( + "W2(shared): bus layout shout lane count drift".into(), + )); + } + let bus_base_delta = bus + .bus_base + .checked_sub(m_in) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): bus_base underflow".into()))?; + if bus_base_delta % t_len != 0 { + return Err(PiCcsError::ProtocolError(format!( + "W2(shared): bus_base alignment mismatch (bus_base_delta={}, t_len={t_len})", + bus_base_delta + ))); + } + let bus_col_offset = bus_base_delta / t_len; + for &lut_idx in decode_lut_indices.iter() { + let inst_cols = bus.shout_cols.get(lut_idx).ok_or_else(|| { + PiCcsError::ProtocolError("W2(shared): missing shout cols for decode lookup table".into()) + })?; + let lane0 = 
inst_cols.lanes.get(0).ok_or_else(|| { + PiCcsError::ProtocolError("W2(shared): expected one shout lane for decode lookup table".into()) + })?; + wp_cols.push(bus_col_offset + lane0.primary_val()); + } + } + if width_stage_required_for_step_witness(step) { + wp_cols.extend(width_lookup_bus_val_cols_witness(step, t_len)?); + } + let mut wp_claims = ts::emit_me_claims_for_mats( + tr, + b"cpu/me_digest_wp_time", + params, + s, + core::slice::from_ref(&mcs_inst.c), + core::slice::from_ref(&mcs_wit.Z), + r_time, + m_in, + )?; + if wp_claims.len() != 1 { + return Err(PiCcsError::ProtocolError(format!( + "WP expects exactly one CPU ME claim at r_time, got {}", + wp_claims.len() + ))); + } + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + m_in, + &wp_cols, + core_t, + &mcs_wit.Z, + &mut wp_claims[0], + )?; + Ok((wb_claims, wp_claims)) +} + +pub(crate) fn verify_route_a_wb_wp_terminals( + core_t: usize, + step: &StepInstanceBundle, + r_time: &[K], + r_cycle: &[K], + batched_final_values: &[K], + claim_plan: &RouteATimeClaimPlan, + mem_proof: &MemSidecarProof, +) -> Result<(), PiCcsError> { + let trace = Rv32TraceLayout::new(); + + if let Some(claim_idx) = claim_plan.wb_bool { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "wb/booleanity claim index out of range".into(), + )); + } + if mem_proof.wb_me_claims.len() != 1 { + return Err(PiCcsError::ProtocolError(format!( + "WB expects exactly one ME claim at r_time (got {})", + mem_proof.wb_me_claims.len() + ))); + } + let me = &mem_proof.wb_me_claims[0]; + if me.r.as_slice() != r_time { + return Err(PiCcsError::ProtocolError( + "WB ME claim r mismatch (expected r_time)".into(), + )); + } + if me.c != step.mcs_inst.c { + return Err(PiCcsError::ProtocolError("WB ME claim commitment mismatch".into())); + } + if me.m_in != step.mcs_inst.m_in { + return Err(PiCcsError::ProtocolError("WB ME claim m_in mismatch".into())); + } + + let 
wb_bool_cols = rv32_trace_wb_columns(&trace); + let need = core_t + .checked_add(wb_bool_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("WB opening count overflow".into()))?; + if me.y_scalars.len() != need { + return Err(PiCcsError::ProtocolError(format!( + "WB ME opening length mismatch (got {}, expected {need})", + me.y_scalars.len() + ))); + } + + let wb_bool_open = &me.y_scalars[core_t..]; + let wb_weights = wb_weight_vector(r_cycle, wb_bool_cols.len()); + let mut wb_weighted_bitness = K::ZERO; + for (&b, &w) in wb_bool_open.iter().zip(wb_weights.iter()) { + wb_weighted_bitness += w * b * (b - K::ONE); + } + + let expected_terminal = eq_points(r_time, r_cycle) * wb_weighted_bitness; + let observed_terminal = batched_final_values[claim_idx]; + if observed_terminal != expected_terminal { + return Err(PiCcsError::ProtocolError( + "wb/booleanity terminal value mismatch".into(), + )); + } + } else if !mem_proof.wb_me_claims.is_empty() { + return Err(PiCcsError::ProtocolError( + "unexpected WB ME claims: wb/booleanity stage is not enabled".into(), + )); + } + + if let Some(claim_idx) = claim_plan.wp_quiescence { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "wp/quiescence claim index out of range".into(), + )); + } + if mem_proof.wp_me_claims.len() != 1 { + return Err(PiCcsError::ProtocolError(format!( + "WP expects exactly one ME claim at r_time (got {})", + mem_proof.wp_me_claims.len() + ))); + } + let me = &mem_proof.wp_me_claims[0]; + if me.r.as_slice() != r_time { + return Err(PiCcsError::ProtocolError( + "WP ME claim r mismatch (expected r_time)".into(), + )); + } + if me.c != step.mcs_inst.c { + return Err(PiCcsError::ProtocolError("WP ME claim commitment mismatch".into())); + } + if me.m_in != step.mcs_inst.m_in { + return Err(PiCcsError::ProtocolError("WP ME claim m_in mismatch".into())); + } + + let wp_open_cols = rv32_trace_wp_opening_columns(&trace); + let need_min = core_t + 
.checked_add(wp_open_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("WP opening count overflow".into()))?; + if me.y_scalars.len() < need_min { + return Err(PiCcsError::ProtocolError(format!( + "WP ME opening length mismatch (got {}, expected at least {need_min})", + me.y_scalars.len() + ))); + } + + let active_open = me + .y_scalars + .get(core_t) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("WP missing active opening".into()))?; + let wp_open_end = core_t + .checked_add(wp_open_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("WP opening end overflow".into()))?; + let wp_open = &me.y_scalars[(core_t + 1)..wp_open_end]; + let wp_weights = wp_weight_vector(r_cycle, wp_open.len()); + let mut wp_weighted_sum = K::ZERO; + for (&v, &w) in wp_open.iter().zip(wp_weights.iter()) { + wp_weighted_sum += w * v; + } + let expected_terminal = eq_points(r_time, r_cycle) * (K::ONE - active_open) * wp_weighted_sum; + let observed_terminal = batched_final_values[claim_idx]; + if observed_terminal != expected_terminal { + return Err(PiCcsError::ProtocolError( + "wp/quiescence terminal value mismatch".into(), + )); + } + } else if !mem_proof.wp_me_claims.is_empty() { + return Err(PiCcsError::ProtocolError( + "unexpected WP ME claims: wp/quiescence stage is not enabled".into(), + )); + } + + Ok(()) +} diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_claims.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_claims.rs new file mode 100644 index 00000000..2d6a322d --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_claims.rs @@ -0,0 +1,1143 @@ +use super::*; + +pub struct RouteAShoutTimeClaimsGuard<'a> { + pub lane_ranges: Vec>, + pub lanes: Vec>, + pub gamma_groups: Vec>, + pub bitness: Vec>>, +} + +pub struct RouteAShoutTimeLaneClaims<'a> { + pub value_prefix: RoundOraclePrefix<'a>, + pub adapter_prefix: RoundOraclePrefix<'a>, + pub event_table_hash_prefix: Option>, + pub value_claim: K, + pub adapter_claim: K, + pub 
event_table_hash_claim: Option, + pub gamma_group: Option, +} + +pub struct RouteAShoutTimeGammaGroupClaims<'a> { + pub value_prefix: RoundOraclePrefix<'a>, + pub adapter_prefix: RoundOraclePrefix<'a>, + pub value_claim: K, + pub adapter_claim: K, +} + +pub fn build_route_a_shout_time_claims_guard<'a>( + shout_oracles: &'a mut [RouteAShoutTimeOracles], + shout_gamma_groups: &'a mut [RouteAShoutGammaGroupOracles], + ell_n: usize, +) -> RouteAShoutTimeClaimsGuard<'a> { + let mut lane_ranges: Vec> = Vec::with_capacity(shout_oracles.len()); + let mut lanes: Vec> = Vec::new(); + let mut gamma_groups: Vec> = Vec::with_capacity(shout_gamma_groups.len()); + let mut bitness: Vec>> = Vec::with_capacity(shout_oracles.len()); + + for o in shout_oracles.iter_mut() { + bitness.push(core::mem::take(&mut o.bitness)); + let start = lanes.len(); + for lane in o.lanes.iter_mut() { + lanes.push(RouteAShoutTimeLaneClaims { + value_prefix: RoundOraclePrefix::new(lane.value.as_mut(), ell_n), + adapter_prefix: RoundOraclePrefix::new(lane.adapter.as_mut(), ell_n), + event_table_hash_prefix: lane + .event_table_hash + .as_deref_mut() + .map(|o| RoundOraclePrefix::new(o, ell_n)), + value_claim: lane.value_claim, + adapter_claim: lane.adapter_claim, + event_table_hash_claim: lane.event_table_hash_claim, + gamma_group: lane.gamma_group, + }); + } + let end = lanes.len(); + lane_ranges.push(start..end); + } + + for g in shout_gamma_groups.iter_mut() { + gamma_groups.push(RouteAShoutTimeGammaGroupClaims { + value_prefix: RoundOraclePrefix::new(g.value.as_mut(), ell_n), + adapter_prefix: RoundOraclePrefix::new(g.adapter.as_mut(), ell_n), + value_claim: g.value_claim, + adapter_claim: g.adapter_claim, + }); + } + + RouteAShoutTimeClaimsGuard { + lane_ranges, + lanes, + gamma_groups, + bitness, + } +} + +pub struct ShoutRouteAProtocol<'a> { + guard: RouteAShoutTimeClaimsGuard<'a>, +} + +impl<'a> ShoutRouteAProtocol<'a> { + pub fn new( + shout_oracles: &'a mut [RouteAShoutTimeOracles], + 
shout_gamma_groups: &'a mut [RouteAShoutGammaGroupOracles], + ell_n: usize, + ) -> Self { + Self { + guard: build_route_a_shout_time_claims_guard(shout_oracles, shout_gamma_groups, ell_n), + } + } +} + +impl<'o> TimeBatchedClaims for ShoutRouteAProtocol<'o> { + fn append_time_claims<'a>( + &'a mut self, + _ell_n: usize, + claimed_sums: &mut Vec, + degree_bounds: &mut Vec, + labels: &mut Vec<&'static [u8]>, + claim_is_dynamic: &mut Vec, + claims: &mut Vec>, + ) { + append_route_a_shout_time_claims( + &mut self.guard, + claimed_sums, + degree_bounds, + labels, + claim_is_dynamic, + claims, + ); + } +} + +pub fn append_route_a_shout_time_claims<'a>( + guard: &'a mut RouteAShoutTimeClaimsGuard<'_>, + claimed_sums: &mut Vec, + degree_bounds: &mut Vec, + labels: &mut Vec<&'static [u8]>, + claim_is_dynamic: &mut Vec, + claims: &mut Vec>, +) { + if guard.lane_ranges.is_empty() { + return; + } + if guard.bitness.len() != guard.lane_ranges.len() { + panic!("shout bitness count mismatch"); + } + + let mut lane_ranges_iter = guard.lane_ranges.iter(); + let mut next_end = lane_ranges_iter.next().expect("non-empty").end; + let mut bitness_iter = guard.bitness.iter_mut(); + + for (lane_idx, lane) in guard.lanes.iter_mut().enumerate() { + if lane.gamma_group.is_none() { + claimed_sums.push(lane.value_claim); + degree_bounds.push(lane.value_prefix.degree_bound()); + labels.push(b"shout/value"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: &mut lane.value_prefix, + claimed_sum: lane.value_claim, + label: b"shout/value", + }); + + claimed_sums.push(lane.adapter_claim); + degree_bounds.push(lane.adapter_prefix.degree_bound()); + labels.push(b"shout/adapter"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: &mut lane.adapter_prefix, + claimed_sum: lane.adapter_claim, + label: b"shout/adapter", + }); + } + + if let Some(prefix) = lane.event_table_hash_prefix.as_mut() { + let claim = lane + .event_table_hash_claim + 
.expect("event_table_hash_claim missing"); + claimed_sums.push(claim); + degree_bounds.push(prefix.degree_bound()); + labels.push(b"shout/event_table_hash"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: prefix, + claimed_sum: claim, + label: b"shout/event_table_hash", + }); + } + + if lane_idx + 1 == next_end { + let bitness_vec = bitness_iter.next().expect("shout bitness idx drift"); + for bit_oracle in bitness_vec.iter_mut() { + claimed_sums.push(K::ZERO); + degree_bounds.push(bit_oracle.degree_bound()); + labels.push(b"shout/bitness"); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle: bit_oracle.as_mut(), + claimed_sum: K::ZERO, + label: b"shout/bitness", + }); + } + + next_end = lane_ranges_iter.next().map(|r| r.end).unwrap_or(usize::MAX); + } + } + + for group in guard.gamma_groups.iter_mut() { + claimed_sums.push(group.value_claim); + degree_bounds.push(group.value_prefix.degree_bound()); + labels.push(b"shout/value"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: &mut group.value_prefix, + claimed_sum: group.value_claim, + label: b"shout/value", + }); + + claimed_sums.push(group.adapter_claim); + degree_bounds.push(group.adapter_prefix.degree_bound()); + labels.push(b"shout/adapter"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: &mut group.adapter_prefix, + claimed_sum: group.adapter_claim, + label: b"shout/adapter", + }); + } + + if bitness_iter.next().is_some() { + panic!("shout bitness not fully consumed"); + } +} + +pub struct RouteATwistTimeClaimsGuard<'a> { + pub read_check_prefixes: Vec>, + pub write_check_prefixes: Vec>, + pub read_check_claims: Vec, + pub write_check_claims: Vec, + pub bitness: Vec>>, +} + +pub fn build_route_a_twist_time_claims_guard<'a>( + twist_oracles: &'a mut [RouteATwistTimeOracles], + ell_n: usize, + read_check_claims: Vec, + write_check_claims: Vec, +) -> RouteATwistTimeClaimsGuard<'a> { + let mut read_check_prefixes: Vec> = 
Vec::with_capacity(twist_oracles.len()); + let mut write_check_prefixes: Vec> = Vec::with_capacity(twist_oracles.len()); + let mut bitness: Vec>> = Vec::with_capacity(twist_oracles.len()); + + if read_check_claims.len() != twist_oracles.len() { + panic!( + "twist read-check claim count mismatch (claims={}, oracles={})", + read_check_claims.len(), + twist_oracles.len() + ); + } + if write_check_claims.len() != twist_oracles.len() { + panic!( + "twist write-check claim count mismatch (claims={}, oracles={})", + write_check_claims.len(), + twist_oracles.len() + ); + } + + for o in twist_oracles.iter_mut() { + bitness.push(core::mem::take(&mut o.bitness)); + read_check_prefixes.push(RoundOraclePrefix::new(o.read_check.as_mut(), ell_n)); + write_check_prefixes.push(RoundOraclePrefix::new(o.write_check.as_mut(), ell_n)); + } + + RouteATwistTimeClaimsGuard { + read_check_prefixes, + write_check_prefixes, + read_check_claims, + write_check_claims, + bitness, + } +} + +pub fn append_route_a_twist_time_claims<'a>( + guard: &'a mut RouteATwistTimeClaimsGuard<'_>, + claimed_sums: &mut Vec, + degree_bounds: &mut Vec, + labels: &mut Vec<&'static [u8]>, + claim_is_dynamic: &mut Vec, + claims: &mut Vec>, +) { + for (((read_check_time, write_check_time), bitness_vec), (read_claim, write_claim)) in guard + .read_check_prefixes + .iter_mut() + .zip(guard.write_check_prefixes.iter_mut()) + .zip(guard.bitness.iter_mut()) + .zip( + guard + .read_check_claims + .iter() + .zip(guard.write_check_claims.iter()), + ) + { + claimed_sums.push(*read_claim); + degree_bounds.push(read_check_time.degree_bound()); + labels.push(b"twist/read_check"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: read_check_time, + claimed_sum: *read_claim, + label: b"twist/read_check", + }); + + claimed_sums.push(*write_claim); + degree_bounds.push(write_check_time.degree_bound()); + labels.push(b"twist/write_check"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: 
write_check_time, + claimed_sum: *write_claim, + label: b"twist/write_check", + }); + + for bit_oracle in bitness_vec.iter_mut() { + claimed_sums.push(K::ZERO); + degree_bounds.push(bit_oracle.degree_bound()); + labels.push(b"twist/bitness"); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle: bit_oracle.as_mut(), + claimed_sum: K::ZERO, + label: b"twist/bitness", + }); + } + } +} + +pub struct TwistRouteAProtocol<'a> { + guard: RouteATwistTimeClaimsGuard<'a>, +} + +impl<'a> TwistRouteAProtocol<'a> { + pub fn new( + twist_oracles: &'a mut [RouteATwistTimeOracles], + ell_n: usize, + read_check_claims: Vec, + write_check_claims: Vec, + ) -> Self { + Self { + guard: build_route_a_twist_time_claims_guard(twist_oracles, ell_n, read_check_claims, write_check_claims), + } + } +} + +impl<'o> TimeBatchedClaims for TwistRouteAProtocol<'o> { + fn append_time_claims<'a>( + &'a mut self, + _ell_n: usize, + claimed_sums: &mut Vec, + degree_bounds: &mut Vec, + labels: &mut Vec<&'static [u8]>, + claim_is_dynamic: &mut Vec, + claims: &mut Vec>, + ) { + append_route_a_twist_time_claims( + &mut self.guard, + claimed_sums, + degree_bounds, + labels, + claim_is_dynamic, + claims, + ); + } +} + +#[inline] +pub(crate) fn has_trace_lookup_families_instance(step: &StepInstanceBundle) -> bool { + step.lut_insts + .iter() + .any(|inst| rv32_is_decode_lookup_table_id(inst.table_id) || rv32_is_width_lookup_table_id(inst.table_id)) +} + +#[inline] +pub(crate) fn has_trace_lookup_families_witness(step: &StepWitnessBundle) -> bool { + step.lut_instances + .iter() + .any(|(inst, _)| rv32_is_decode_lookup_table_id(inst.table_id) || rv32_is_width_lookup_table_id(inst.table_id)) +} + +#[inline] +pub(crate) fn wb_wp_required_for_step_instance(step: &StepInstanceBundle) -> bool { + // Stage gating is keyed by lookup-family presence instead of fixed `m_in`/`mem_id` + // assumptions so adapter-side routing can evolve without hardcoding RV32 shapes here. 
+ has_trace_lookup_families_instance(step) +} + +#[inline] +pub(crate) fn wb_wp_required_for_step_witness(step: &StepWitnessBundle) -> bool { + has_trace_lookup_families_witness(step) +} + +pub(crate) fn build_bus_layout_for_step_witness( + step: &StepWitnessBundle, + t_len: usize, +) -> Result { + let m = step.mcs.1.Z.cols(); + let m_in = step.mcs.0.m_in; + let shout_shapes: Vec = step + .lut_instances + .iter() + .map(|(inst, _)| ShoutInstanceShape { + ell_addr: inst.d * inst.ell, + lanes: inst.lanes.max(1), + n_vals: 1usize, + addr_group: inst.addr_group, + selector_group: inst.selector_group, + }) + .collect(); + let grouped_shout_instances = shout_shapes + .iter() + .filter(|shape| shape.addr_group.is_some()) + .count(); + let twist = step + .mem_instances + .iter() + .map(|(inst, _)| (inst.d * inst.ell, inst.lanes.max(1))); + build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes(m, m_in, t_len, shout_shapes, twist).map_err( + |e| { + PiCcsError::InvalidInput(format!( + "step bus layout failed: m={m}, m_in={m_in}, t_len={t_len}, lut_insts={}, grouped_lut_insts={grouped_shout_instances}: {e}", + step.lut_instances.len() + )) + }, + ) +} + +#[inline] +pub(crate) fn decode_stage_required_for_step_instance(step: &StepInstanceBundle) -> bool { + wb_wp_required_for_step_instance(step) + && step + .lut_insts + .iter() + .any(|inst| rv32_is_decode_lookup_table_id(inst.table_id)) +} + +#[inline] +pub(crate) fn decode_stage_required_for_step_witness(step: &StepWitnessBundle) -> bool { + wb_wp_required_for_step_witness(step) + && step + .lut_instances + .iter() + .any(|(inst, _)| rv32_is_decode_lookup_table_id(inst.table_id)) +} + +#[inline] +pub(crate) fn width_stage_required_for_step_instance(step: &StepInstanceBundle) -> bool { + wb_wp_required_for_step_instance(step) + && step + .lut_insts + .iter() + .any(|inst| rv32_is_width_lookup_table_id(inst.table_id)) +} + +#[inline] +pub(crate) fn width_stage_required_for_step_witness(step: &StepWitnessBundle) -> 
bool { + wb_wp_required_for_step_witness(step) + && step + .lut_instances + .iter() + .any(|(inst, _)| rv32_is_width_lookup_table_id(inst.table_id)) +} + +#[inline] +pub(crate) fn control_stage_required_for_step_instance(step: &StepInstanceBundle) -> bool { + decode_stage_required_for_step_instance(step) +} + +#[inline] +pub(crate) fn control_stage_required_for_step_witness(step: &StepWitnessBundle) -> bool { + decode_stage_required_for_step_witness(step) +} + +pub(crate) fn build_route_a_wb_wp_time_claims( + params: &NeoParams, + step: &StepWitnessBundle, + r_cycle: &[K], +) -> Result<(Option<(Box, K)>, Option<(Box, K)>), PiCcsError> { + if !wb_wp_required_for_step_witness(step) { + return Ok((None, None)); + } + + let trace = Rv32TraceLayout::new(); + let t_len = infer_rv32_trace_t_len_for_wb_wp(step, &trace)?; + let m_in = step.mcs.0.m_in; + let ell_n = r_cycle.len(); + let wb_bool_cols = rv32_trace_wb_columns(&trace); + let wp_cols = rv32_trace_wp_columns(&trace); + + let mut decode_cols = Vec::with_capacity(1 + wb_bool_cols.len() + wp_cols.len()); + decode_cols.push(trace.active); + decode_cols.extend(wb_bool_cols.iter().copied()); + decode_cols.extend(wp_cols.iter().copied()); + let decoded = decode_trace_col_values_batch(params, step, t_len, &decode_cols)?; + + let wb_weights = wb_weight_vector(r_cycle, wb_bool_cols.len()); + let mut wb_bool_sparse_cols: Vec> = Vec::with_capacity(wb_bool_cols.len()); + for &col_id in wb_bool_cols.iter() { + let vals = decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("WB: missing decoded bool column {col_id}")))?; + wb_bool_sparse_cols.push(sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + + let wb_oracle = LazyWeightedBitnessOracleSparseTime::new_with_cycle(r_cycle, wb_bool_sparse_cols, wb_weights); + + let wp_cols = rv32_trace_wp_columns(&trace); + let weights = wp_weight_vector(r_cycle, wp_cols.len()); + let active_vals = decoded + .get(&trace.active) + .ok_or_else(|| 
PiCcsError::ProtocolError(format!("WP: missing decoded active column {}", trace.active)))?; + let active = sparse_trace_col_from_values(m_in, ell_n, &active_vals)?; + + let mut sparse_cols: Vec> = Vec::with_capacity(wp_cols.len()); + for &col_id in wp_cols.iter() { + let vals = decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("WP: missing decoded column {col_id}")))?; + sparse_cols.push(sparse_trace_col_from_values(m_in, ell_n, &vals)?); + } + + let oracle = WeightedMaskOracleSparseTime::new(active, sparse_cols, weights, r_cycle); + Ok((Some((Box::new(wb_oracle), K::ZERO)), Some((Box::new(oracle), K::ZERO)))) +} + +pub(crate) fn build_route_a_decode_time_claims( + params: &NeoParams, + step: &StepWitnessBundle, + r_cycle: &[K], +) -> Result<(Option<(Box, K)>, Option<(Box, K)>), PiCcsError> { + if !decode_stage_required_for_step_witness(step) { + return Ok((None, None)); + } + + let trace = Rv32TraceLayout::new(); + let decode = Rv32DecodeSidecarLayout::new(); + let t_len = infer_rv32_trace_t_len_for_wb_wp(step, &trace)?; + let m_in = step.mcs.0.m_in; + let ell_n = r_cycle.len(); + + let cpu_cols = vec![ + trace.active, + trace.halted, + trace.instr_word, + trace.rs1_val, + trace.rs2_val, + trace.rd_val, + trace.ram_addr, + trace.shout_has_lookup, + trace.shout_val, + trace.shout_lhs, + trace.shout_rhs, + ]; + let cpu_decoded = decode_trace_col_values_batch(params, step, t_len, &cpu_cols)?; + + let decode_decoded = { + let instr_vals = cpu_decoded + .get(&trace.instr_word) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing instr_word decode column".into()))?; + let active_vals = cpu_decoded + .get(&trace.active) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing active decode column".into()))?; + if instr_vals.len() != t_len || active_vals.len() != t_len { + return Err(PiCcsError::ProtocolError(format!( + "W2(shared): decoded CPU column lengths drift (instr={}, active={}, t_len={t_len})", + instr_vals.len(), + 
active_vals.len() + ))); + } + let mut decoded = BTreeMap::>::new(); + for col_id in 0..decode.cols { + decoded.insert(col_id, Vec::with_capacity(t_len)); + } + for j in 0..t_len { + let instr_word = decode_k_to_u32(instr_vals[j], "W2(shared)/instr_word")?; + let active = active_vals[j] != K::ZERO; + let mut row = rv32_decode_lookup_backed_row_from_instr_word(&decode, instr_word, active); + if !active { + row.fill(F::ZERO); + } + for (col_id, value) in row.into_iter().enumerate() { + decoded + .get_mut(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): decode map build failed".into()))? + .push(K::from(value)); + } + } + + // In shared lookup-backed mode, overwrite lookup-backed decode columns with the values + // actually committed on the shared Shout bus so prover oracles and verifier terminals + // are sourced from identical openings. + let (decode_open_cols, decode_lut_indices) = resolve_shared_decode_lookup_lut_indices(step, &decode)?; + let bus = build_bus_layout_for_step_witness(step, t_len)?; + if bus.shout_cols.len() != step.lut_instances.len() { + return Err(PiCcsError::ProtocolError( + "W2(shared): bus layout shout lane count drift".into(), + )); + } + let mut bus_val_cols = Vec::with_capacity(decode_open_cols.len()); + for &lut_idx in decode_lut_indices.iter() { + let inst_cols = bus.shout_cols.get(lut_idx).ok_or_else(|| { + PiCcsError::ProtocolError("W2(shared): missing shout cols for decode lookup table".into()) + })?; + let lane0 = inst_cols.lanes.get(0).ok_or_else(|| { + PiCcsError::ProtocolError("W2(shared): expected one shout lane for decode lookup table".into()) + })?; + bus_val_cols.push(lane0.primary_val()); + } + let lookup_vals = decode_lookup_backed_col_values_batch( + params, + bus.bus_base, + t_len, + &step.mcs.1.Z, + bus.bus_cols, + &bus_val_cols, + )?; + for (open_idx, &decode_col_id) in decode_open_cols.iter().enumerate() { + let bus_col_id = bus_val_cols[open_idx]; + let values = lookup_vals.get(&bus_col_id).ok_or_else(|| 
{ + PiCcsError::ProtocolError(format!( + "W2(shared): missing decoded lookup values for bus_col={bus_col_id}" + )) + })?; + decoded.insert(decode_col_id, values.clone()); + } + + // Recompute derived decode helper columns from opened lookup-backed decode columns. + let rd_is_zero_vals = decoded + .get(&decode.rd_is_zero) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing rd_is_zero decode column".into()))?; + let funct7_b5_vals = decoded + .get(&decode.funct7_bit[5]) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing funct7_bit[5] decode column".into()))?; + let op_lui_vals = decoded + .get(&decode.op_lui) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing op_lui decode column".into()))?; + let op_auipc_vals = decoded + .get(&decode.op_auipc) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing op_auipc decode column".into()))?; + let op_jal_vals = decoded + .get(&decode.op_jal) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing op_jal decode column".into()))?; + let op_jalr_vals = decoded + .get(&decode.op_jalr) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing op_jalr decode column".into()))?; + let op_alu_imm_vals = decoded + .get(&decode.op_alu_imm) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing op_alu_imm decode column".into()))?; + let op_alu_reg_vals = decoded + .get(&decode.op_alu_reg) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing op_alu_reg decode column".into()))?; + let funct3_is0_vals = decoded + .get(&decode.funct3_is[0]) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing funct3_is[0] decode column".into()))?; + let funct3_is1_vals = decoded + .get(&decode.funct3_is[1]) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing funct3_is[1] decode column".into()))?; + let funct3_is5_vals = decoded + .get(&decode.funct3_is[5]) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing funct3_is[5] decode 
column".into()))?; + let rs2_vals = decoded + .get(&decode.rs2) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing rs2 decode column".into()))?; + let imm_i_vals = decoded + .get(&decode.imm_i) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): missing imm_i decode column".into()))?; + + let mut op_lui_write = Vec::with_capacity(t_len); + let mut op_auipc_write = Vec::with_capacity(t_len); + let mut op_jal_write = Vec::with_capacity(t_len); + let mut op_jalr_write = Vec::with_capacity(t_len); + let mut op_alu_imm_write = Vec::with_capacity(t_len); + let mut op_alu_reg_write = Vec::with_capacity(t_len); + let mut alu_reg_delta = Vec::with_capacity(t_len); + let mut alu_imm_delta = Vec::with_capacity(t_len); + let mut alu_imm_shift_rhs_delta = Vec::with_capacity(t_len); + for j in 0..t_len { + let rd_keep = K::ONE - rd_is_zero_vals[j]; + op_lui_write.push(op_lui_vals[j] * rd_keep); + op_auipc_write.push(op_auipc_vals[j] * rd_keep); + op_jal_write.push(op_jal_vals[j] * rd_keep); + op_jalr_write.push(op_jalr_vals[j] * rd_keep); + op_alu_imm_write.push(op_alu_imm_vals[j] * rd_keep); + op_alu_reg_write.push(op_alu_reg_vals[j] * rd_keep); + alu_reg_delta.push(funct7_b5_vals[j] * (funct3_is0_vals[j] + funct3_is5_vals[j])); + alu_imm_delta.push(funct7_b5_vals[j] * funct3_is5_vals[j]); + alu_imm_shift_rhs_delta.push((funct3_is1_vals[j] + funct3_is5_vals[j]) * (rs2_vals[j] - imm_i_vals[j])); + } + decoded.insert(decode.op_lui_write, op_lui_write); + decoded.insert(decode.op_auipc_write, op_auipc_write); + decoded.insert(decode.op_jal_write, op_jal_write); + decoded.insert(decode.op_jalr_write, op_jalr_write); + decoded.insert(decode.op_alu_imm_write, op_alu_imm_write); + decoded.insert(decode.op_alu_reg_write, op_alu_reg_write); + decoded.insert(decode.alu_reg_table_delta, alu_reg_delta); + decoded.insert(decode.alu_imm_table_delta, alu_imm_delta); + decoded.insert(decode.alu_imm_shift_rhs_delta, alu_imm_shift_rhs_delta); + + decoded + }; + + let 
cpu_value_at = |col_id: usize, row: usize| -> Result { + cpu_decoded + .get(&col_id) + .and_then(|v| v.get(row)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2 missing CPU decoded column {col_id}"))) + }; + let decode_value_at = |col_id: usize, row: usize| -> Result { + decode_decoded + .get(&col_id) + .and_then(|v| v.get(row)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2 missing decode lookup-backed column {col_id}"))) + }; + + let mut imm_residual_vals: Vec> = (0..W2_IMM_RESIDUAL_COUNT) + .map(|_| Vec::with_capacity(t_len)) + .collect(); + for j in 0..t_len { + let active = cpu_value_at(trace.active, j)?; + let halted = cpu_value_at(trace.halted, j)?; + let decode_opcode = decode_value_at(decode.opcode, j)?; + let rd_has_write = decode_value_at(decode.rd_has_write, j)?; + let rd_is_zero = decode_value_at(decode.rd_is_zero, j)?; + let rs1_val = cpu_value_at(trace.rs1_val, j)?; + let rs2_val = cpu_value_at(trace.rs2_val, j)?; + let rd_val = cpu_value_at(trace.rd_val, j)?; + let ram_has_read = decode_value_at(decode.ram_has_read, j)?; + let ram_has_write = decode_value_at(decode.ram_has_write, j)?; + let ram_addr = cpu_value_at(trace.ram_addr, j)?; + let shout_has_lookup = cpu_value_at(trace.shout_has_lookup, j)?; + let shout_val = cpu_value_at(trace.shout_val, j)?; + let shout_lhs = cpu_value_at(trace.shout_lhs, j)?; + let shout_rhs = cpu_value_at(trace.shout_rhs, j)?; + let opcode_flags = [ + decode_value_at(decode.op_lui, j)?, + decode_value_at(decode.op_auipc, j)?, + decode_value_at(decode.op_jal, j)?, + decode_value_at(decode.op_jalr, j)?, + decode_value_at(decode.op_branch, j)?, + decode_value_at(decode.op_load, j)?, + decode_value_at(decode.op_store, j)?, + decode_value_at(decode.op_alu_imm, j)?, + decode_value_at(decode.op_alu_reg, j)?, + decode_value_at(decode.op_misc_mem, j)?, + decode_value_at(decode.op_system, j)?, + decode_value_at(decode.op_amo, j)?, + ]; + let funct3_is = [ + 
decode_value_at(decode.funct3_is[0], j)?, + decode_value_at(decode.funct3_is[1], j)?, + decode_value_at(decode.funct3_is[2], j)?, + decode_value_at(decode.funct3_is[3], j)?, + decode_value_at(decode.funct3_is[4], j)?, + decode_value_at(decode.funct3_is[5], j)?, + decode_value_at(decode.funct3_is[6], j)?, + decode_value_at(decode.funct3_is[7], j)?, + ]; + let rs2_decode = decode_value_at(decode.rs2, j)?; + let imm_i = decode_value_at(decode.imm_i, j)?; + let imm_s = decode_value_at(decode.imm_s, j)?; + + let funct3_bits = [ + decode_value_at(decode.funct3_bit[0], j)?, + decode_value_at(decode.funct3_bit[1], j)?, + decode_value_at(decode.funct3_bit[2], j)?, + ]; + let funct7_bits = [ + decode_value_at(decode.funct7_bit[0], j)?, + decode_value_at(decode.funct7_bit[1], j)?, + decode_value_at(decode.funct7_bit[2], j)?, + decode_value_at(decode.funct7_bit[3], j)?, + decode_value_at(decode.funct7_bit[4], j)?, + decode_value_at(decode.funct7_bit[5], j)?, + decode_value_at(decode.funct7_bit[6], j)?, + ]; + let imm = w2_decode_immediate_residuals( + decode_value_at(decode.imm_i, j)?, + decode_value_at(decode.imm_s, j)?, + decode_value_at(decode.imm_b, j)?, + decode_value_at(decode.imm_j, j)?, + [ + decode_value_at(decode.rd_bit[0], j)?, + decode_value_at(decode.rd_bit[1], j)?, + decode_value_at(decode.rd_bit[2], j)?, + decode_value_at(decode.rd_bit[3], j)?, + decode_value_at(decode.rd_bit[4], j)?, + ], + funct3_bits, + [ + decode_value_at(decode.rs1_bit[0], j)?, + decode_value_at(decode.rs1_bit[1], j)?, + decode_value_at(decode.rs1_bit[2], j)?, + decode_value_at(decode.rs1_bit[3], j)?, + decode_value_at(decode.rs1_bit[4], j)?, + ], + [ + decode_value_at(decode.rs2_bit[0], j)?, + decode_value_at(decode.rs2_bit[1], j)?, + decode_value_at(decode.rs2_bit[2], j)?, + decode_value_at(decode.rs2_bit[3], j)?, + decode_value_at(decode.rs2_bit[4], j)?, + ], + funct7_bits, + ); + + let op_write_flags = [ + opcode_flags[0] * (K::ONE - rd_is_zero), + opcode_flags[1] * (K::ONE - 
rd_is_zero), + opcode_flags[2] * (K::ONE - rd_is_zero), + opcode_flags[3] * (K::ONE - rd_is_zero), + opcode_flags[7] * (K::ONE - rd_is_zero), + opcode_flags[8] * (K::ONE - rd_is_zero), + ]; + let shout_table_id = decode_value_at(decode.shout_table_id, j)?; + let alu_reg_table_delta = funct7_bits[5] * (funct3_is[0] + funct3_is[5]); + let alu_imm_table_delta = funct7_bits[5] * funct3_is[5]; + let alu_imm_shift_rhs_delta = (funct3_is[1] + funct3_is[5]) * (rs2_decode - imm_i); + let selector_residuals = w2_decode_selector_residuals( + active, + decode_opcode, + opcode_flags, + funct3_is, + funct3_bits, + opcode_flags[11], + ); + let bitness_residuals = w2_decode_bitness_residuals(opcode_flags, funct3_is); + let alu_branch_residuals = w2_alu_branch_lookup_residuals( + active, + halted, + shout_has_lookup, + shout_lhs, + shout_rhs, + shout_table_id, + rs1_val, + rs2_val, + rd_has_write, + rd_is_zero, + rd_val, + ram_has_read, + ram_has_write, + ram_addr, + shout_val, + funct3_bits, + funct7_bits, + opcode_flags, + op_write_flags, + funct3_is, + alu_reg_table_delta, + alu_imm_table_delta, + alu_imm_shift_rhs_delta, + rs2_decode, + imm_i, + imm_s, + ); + if let Some((idx, _)) = selector_residuals + .iter() + .enumerate() + .find(|(_, r)| **r != K::ZERO) + { + return Err(PiCcsError::ProtocolError(format!( + "decode/fields selector residual non-zero at row={j}, idx={idx}" + ))); + } + if let Some((idx, _)) = bitness_residuals + .iter() + .enumerate() + .find(|(_, r)| **r != K::ZERO) + { + return Err(PiCcsError::ProtocolError(format!( + "decode/fields bitness residual non-zero at row={j}, idx={idx}" + ))); + } + if let Some((idx, _)) = alu_branch_residuals + .iter() + .enumerate() + .find(|(_, r)| **r != K::ZERO) + { + return Err(PiCcsError::ProtocolError(format!( + "decode/fields alu_branch residual non-zero at row={j}, idx={idx}" + ))); + } + + for (k, r) in imm.iter().enumerate() { + imm_residual_vals[k].push(*r); + } + } + + let main_field_cols = vec![ + trace.active, + 
trace.halted, + trace.rs1_val, + trace.rs2_val, + trace.rd_val, + trace.ram_addr, + trace.shout_has_lookup, + trace.shout_val, + trace.shout_lhs, + trace.shout_rhs, + ]; + let decode_field_cols = vec![ + decode.opcode, + decode.rd_is_zero, + decode.rd_has_write, + decode.ram_has_read, + decode.ram_has_write, + decode.shout_table_id, + decode.op_lui, + decode.op_auipc, + decode.op_jal, + decode.op_jalr, + decode.op_branch, + decode.op_load, + decode.op_store, + decode.op_alu_imm, + decode.op_alu_reg, + decode.op_misc_mem, + decode.op_system, + decode.op_amo, + decode.funct3_is[0], + decode.funct3_is[1], + decode.funct3_is[2], + decode.funct3_is[3], + decode.funct3_is[4], + decode.funct3_is[5], + decode.funct3_is[6], + decode.funct3_is[7], + decode.funct3_bit[0], + decode.funct3_bit[1], + decode.funct3_bit[2], + decode.funct7_bit[0], + decode.funct7_bit[1], + decode.funct7_bit[2], + decode.funct7_bit[3], + decode.funct7_bit[4], + decode.funct7_bit[5], + decode.funct7_bit[6], + decode.rs2, + decode.imm_i, + decode.imm_s, + ]; + let mut main_sparse = BTreeMap::>::new(); + for &col_id in main_field_cols.iter() { + let vals = cpu_decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2 missing CPU decoded column {col_id}")))?; + main_sparse.insert(col_id, sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + let mut decode_sparse = BTreeMap::>::new(); + for &col_id in decode_field_cols.iter() { + let vals = decode_decoded + .get(&col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2 missing decode lookup-backed column {col_id}")))?; + decode_sparse.insert(col_id, sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + let main_col = |col_id: usize| -> Result, PiCcsError> { + main_sparse + .get(&col_id) + .cloned() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2 missing main sparse column {col_id}"))) + }; + let decode_col = |col_id: usize| -> Result, PiCcsError> { + decode_sparse + .get(&col_id) + .cloned() + .ok_or_else(|| 
PiCcsError::ProtocolError(format!("W2 missing decode sparse column {col_id}"))) + }; + + let mut fields_sparse_cols = Vec::with_capacity(main_field_cols.len() + decode_field_cols.len()); + for &col_id in main_field_cols.iter() { + fields_sparse_cols.push(main_col(col_id)?); + } + for &col_id in decode_field_cols.iter() { + fields_sparse_cols.push(decode_col(col_id)?); + } + + let mut imm_sparse_cols = Vec::with_capacity(imm_residual_vals.len()); + for vals in imm_residual_vals.iter() { + imm_sparse_cols.push(sparse_trace_col_from_values(m_in, ell_n, vals)?); + } + + let pow2_cycle = 1usize + .checked_shl(ell_n as u32) + .ok_or_else(|| PiCcsError::InvalidInput("W2: 2^ell_n overflow".into()))?; + let active_zero = SparseIdxVec::from_entries(pow2_cycle, Vec::new()); + let fields_weights = w2_decode_pack_weight_vector(r_cycle, W2_FIELDS_RESIDUAL_COUNT); + let fields_oracle = FormulaOracleSparseTime::new( + fields_sparse_cols, + 4, + r_cycle, + Box::new(move |vals: &[K]| { + let mut idx = 0usize; + let active = vals[idx]; + idx += 1; + let halted = vals[idx]; + idx += 1; + let rs1_val = vals[idx]; + idx += 1; + let rs2_val = vals[idx]; + idx += 1; + let rd_val = vals[idx]; + idx += 1; + let ram_addr = vals[idx]; + idx += 1; + let shout_has_lookup = vals[idx]; + idx += 1; + let shout_val = vals[idx]; + idx += 1; + let shout_lhs = vals[idx]; + idx += 1; + let shout_rhs = vals[idx]; + idx += 1; + let decode_opcode = vals[idx]; + idx += 1; + let rd_is_zero = vals[idx]; + idx += 1; + let rd_has_write = vals[idx]; + idx += 1; + let ram_has_read = vals[idx]; + idx += 1; + let ram_has_write = vals[idx]; + idx += 1; + let shout_table_id = vals[idx]; + idx += 1; + let opcode_flags = [ + vals[idx], + vals[idx + 1], + vals[idx + 2], + vals[idx + 3], + vals[idx + 4], + vals[idx + 5], + vals[idx + 6], + vals[idx + 7], + vals[idx + 8], + vals[idx + 9], + vals[idx + 10], + vals[idx + 11], + ]; + idx += 12; + let funct3_is = [ + vals[idx], + vals[idx + 1], + vals[idx + 2], + vals[idx + 
3], + vals[idx + 4], + vals[idx + 5], + vals[idx + 6], + vals[idx + 7], + ]; + idx += 8; + let funct3_bits = [vals[idx], vals[idx + 1], vals[idx + 2]]; + idx += 3; + let funct7_bits = [ + vals[idx], + vals[idx + 1], + vals[idx + 2], + vals[idx + 3], + vals[idx + 4], + vals[idx + 5], + vals[idx + 6], + ]; + idx += 7; + let rs2_decode = vals[idx]; + idx += 1; + let imm_i = vals[idx]; + idx += 1; + let imm_s = vals[idx]; + let rd_keep = K::ONE - rd_is_zero; + let op_write_flags = [ + opcode_flags[0] * rd_keep, + opcode_flags[1] * rd_keep, + opcode_flags[2] * rd_keep, + opcode_flags[3] * rd_keep, + opcode_flags[7] * rd_keep, + opcode_flags[8] * rd_keep, + ]; + let alu_reg_table_delta = funct7_bits[5] * (funct3_is[0] + funct3_is[5]); + let alu_imm_table_delta = funct7_bits[5] * funct3_is[5]; + let alu_imm_shift_rhs_delta = (funct3_is[1] + funct3_is[5]) * (rs2_decode - imm_i); + let selector_residuals = w2_decode_selector_residuals( + active, + decode_opcode, + opcode_flags, + funct3_is, + funct3_bits, + opcode_flags[11], + ); + let bitness_residuals = w2_decode_bitness_residuals(opcode_flags, funct3_is); + let alu_branch_residuals = w2_alu_branch_lookup_residuals( + active, + halted, + shout_has_lookup, + shout_lhs, + shout_rhs, + shout_table_id, + rs1_val, + rs2_val, + rd_has_write, + rd_is_zero, + rd_val, + ram_has_read, + ram_has_write, + ram_addr, + shout_val, + funct3_bits, + funct7_bits, + opcode_flags, + op_write_flags, + funct3_is, + alu_reg_table_delta, + alu_imm_table_delta, + alu_imm_shift_rhs_delta, + rs2_decode, + imm_i, + imm_s, + ); + let mut weighted = K::ZERO; + let mut w_idx = 0usize; + for r in selector_residuals { + weighted += fields_weights[w_idx] * r; + w_idx += 1; + } + for r in bitness_residuals { + weighted += fields_weights[w_idx] * r; + w_idx += 1; + } + for r in alu_branch_residuals { + weighted += fields_weights[w_idx] * r; + w_idx += 1; + } + debug_assert_eq!(w_idx, fields_weights.len()); + debug_assert_eq!(idx + 1, vals.len()); + weighted 
+ }), + ); + let imm_oracle = WeightedMaskOracleSparseTime::new( + active_zero, + imm_sparse_cols, + w2_decode_imm_weight_vector(r_cycle, 4), + r_cycle, + ); + + Ok(( + Some((Box::new(fields_oracle), K::ZERO)), + Some((Box::new(imm_oracle), K::ZERO)), + )) +} + +pub(crate) type W3TimeClaims = ( + Option<(Box, K)>, + Option<(Box, K)>, + Option<(Box, K)>, + Option<(Box, K)>, + Option<(Box, K)>, +); diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_finalize.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_finalize.rs new file mode 100644 index 00000000..c4902294 --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_finalize.rs @@ -0,0 +1,493 @@ +use super::*; + +pub(crate) fn finalize_route_a_memory_prover( + tr: &mut Poseidon2Transcript, + params: &NeoParams, + cpu_bus: &BusLayout, + s: &CcsStructure, + step: &StepWitnessBundle, + prev_step: Option<&StepWitnessBundle>, + prev_twist_decoded: Option<&[TwistDecodedColsSparse]>, + oracles: &mut RouteAMemoryOracles, + shout_addr_pre: &ShoutAddrPreProof, + twist_pre: &[TwistAddrPreProverData], + r_time: &[K], + m_in: usize, + step_idx: usize, +) -> Result, PiCcsError> { + let has_prev = prev_step.is_some(); + if has_prev != prev_twist_decoded.is_some() { + return Err(PiCcsError::InvalidInput(format!( + "Twist rollover decoded cache mismatch: prev_step.is_some()={} but prev_twist_decoded.is_some()={}", + has_prev, + prev_twist_decoded.is_some() + ))); + } + let total_lanes: usize = step + .lut_instances + .iter() + .map(|(inst, _)| inst.lanes.max(1)) + .sum(); + if shout_addr_pre.claimed_sums.len() != total_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shout addr-pre proof count mismatch (expected claimed_sums.len()=total_lanes={}, got {})", + total_lanes, + shout_addr_pre.claimed_sums.len(), + ))); + } + { + let mut lane_ell_addr: Vec = Vec::with_capacity(total_lanes); + let mut required_ell_addrs: std::collections::BTreeSet = std::collections::BTreeSet::new(); + for 
(lut_inst, _lut_wit) in step.lut_instances.iter().map(|(inst, wit)| (inst, wit)) { + let inst_ell_addr = lut_inst.d * lut_inst.ell; + let inst_ell_addr_u32 = u32::try_from(inst_ell_addr) + .map_err(|_| PiCcsError::InvalidInput("Shout: ell_addr overflows u32".into()))?; + required_ell_addrs.insert(inst_ell_addr_u32); + for _lane_idx in 0..lut_inst.lanes.max(1) { + lane_ell_addr.push(inst_ell_addr_u32); + } + } + if lane_ell_addr.len() != total_lanes { + return Err(PiCcsError::ProtocolError( + "shout addr-pre lane indexing drift (lane_ell_addr)".into(), + )); + } + + if shout_addr_pre.groups.len() != required_ell_addrs.len() { + return Err(PiCcsError::InvalidInput(format!( + "shout addr-pre group count mismatch (expected {}, got {})", + required_ell_addrs.len(), + shout_addr_pre.groups.len() + ))); + } + let required_list: Vec = required_ell_addrs.into_iter().collect(); + let mut seen_active: std::collections::HashSet = std::collections::HashSet::new(); + for (idx, group) in shout_addr_pre.groups.iter().enumerate() { + let expected_ell_addr = required_list[idx]; + if group.ell_addr != expected_ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shout addr-pre groups not sorted or mismatched: groups[{idx}].ell_addr={} but expected {expected_ell_addr}", + group.ell_addr + ))); + } + if group.r_addr.len() != group.ell_addr as usize { + return Err(PiCcsError::InvalidInput(format!( + "shout addr-pre group ell_addr={} has r_addr.len()={}, expected {}", + group.ell_addr, + group.r_addr.len(), + group.ell_addr + ))); + } + if group.round_polys.len() != group.active_lanes.len() { + return Err(PiCcsError::InvalidInput(format!( + "shout addr-pre group ell_addr={} round_polys.len()={}, expected active_lanes.len()={}", + group.ell_addr, + group.round_polys.len(), + group.active_lanes.len() + ))); + } + for (pos, &lane_idx) in group.active_lanes.iter().enumerate() { + let lane_idx_usize = lane_idx as usize; + if lane_idx_usize >= total_lanes { + return 
Err(PiCcsError::InvalidInput( + "shout addr-pre active_lanes has index out of range".into(), + )); + } + if lane_ell_addr[lane_idx_usize] != group.ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shout addr-pre active_lanes contains lane_idx={} with ell_addr={}, but group ell_addr={}", + lane_idx, lane_ell_addr[lane_idx_usize], group.ell_addr + ))); + } + if pos > 0 && group.active_lanes[pos - 1] >= lane_idx { + return Err(PiCcsError::InvalidInput( + "shout addr-pre active_lanes must be strictly increasing".into(), + )); + } + if !seen_active.insert(lane_idx) { + return Err(PiCcsError::InvalidInput( + "shout addr-pre active_lanes contains duplicates across groups".into(), + )); + } + } + for (pos, rounds) in group.round_polys.iter().enumerate() { + if rounds.len() != group.ell_addr as usize { + return Err(PiCcsError::InvalidInput(format!( + "shout addr-pre group ell_addr={} round_polys[{pos}].len()={}, expected {}", + group.ell_addr, + rounds.len(), + group.ell_addr + ))); + } + } + } + } + if twist_pre.len() != step.mem_instances.len() { + return Err(PiCcsError::InvalidInput(format!( + "twist pre-time count mismatch (expected {}, got {})", + step.mem_instances.len(), + twist_pre.len() + ))); + } + if oracles.twist.len() != step.mem_instances.len() { + return Err(PiCcsError::InvalidInput(format!( + "twist oracle count mismatch (expected {}, got {})", + step.mem_instances.len(), + oracles.twist.len() + ))); + } + + for (idx, (lut_inst, lut_wit)) in step.lut_instances.iter().enumerate() { + if !lut_inst.comms.is_empty() || !lut_wit.mats.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "shared CPU bus requires metadata-only Shout instances (comms/mats must be empty, lut_idx={idx})" + ))); + } + } + for (idx, (mem_inst, mem_wit)) in step.mem_instances.iter().enumerate() { + if !mem_inst.comms.is_empty() || !mem_wit.mats.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "shared CPU bus requires metadata-only Twist instances 
(comms/mats must be empty, mem_idx={idx})" + ))); + } + } + if let Some(prev) = prev_step { + if prev.mem_instances.len() != step.mem_instances.len() { + return Err(PiCcsError::InvalidInput(format!( + "Twist rollover requires stable mem instance count: prev has {}, current has {}", + prev.mem_instances.len(), + step.mem_instances.len() + ))); + } + for (idx, (mem_inst, mem_wit)) in prev.mem_instances.iter().enumerate() { + if !mem_inst.comms.is_empty() || !mem_wit.mats.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "shared CPU bus requires metadata-only Twist instances (comms/mats must be empty, prev mem_idx={idx})" + ))); + } + } + } + + let mut val_me_claims: Vec> = Vec::new(); + let mut wb_me_claims: Vec> = Vec::new(); + let mut wp_me_claims: Vec> = Vec::new(); + let mut proofs: Vec = Vec::new(); + + // -------------------------------------------------------------------- + // Phase 2: Twist val-eval sum-check (batched across mem instances). + // -------------------------------------------------------------------- + let mut twist_val_eval_proofs: Vec> = Vec::new(); + let mut r_val: Vec = Vec::new(); + if !step.mem_instances.is_empty() { + let plan = crate::memory_sidecar::claim_plan::TwistValEvalClaimPlan::build( + step.mem_instances.iter().map(|(inst, _)| inst), + has_prev, + ); + let n_mem = step.mem_instances.len(); + let claims_per_mem = plan.claims_per_mem; + let claim_count = plan.claim_count; + + let mut val_oracles: Vec> = Vec::with_capacity(claim_count); + let mut bind_claims: Vec<(u8, K)> = Vec::with_capacity(claim_count); + let mut claimed_sums: Vec = Vec::with_capacity(claim_count); + + let mut claimed_inc_sums_lt: Vec = Vec::with_capacity(n_mem); + let mut claimed_inc_sums_total: Vec = Vec::with_capacity(n_mem); + let mut claimed_prev_inc_sums_total: Vec> = Vec::with_capacity(n_mem); + + let mut claim_idx = 0usize; + for (i_mem, (mem_inst, _mem_wit)) in step.mem_instances.iter().enumerate() { + let pre = twist_pre + .get(i_mem) + 
.ok_or_else(|| PiCcsError::ProtocolError("missing Twist pre-time data".into()))?; + let decoded = &pre.decoded; + let r_addr = &pre.addr_pre.r_addr; + if decoded.lanes.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "Twist(Route A): decoded lanes empty at mem_idx={i_mem}" + ))); + } + + let mut lt_oracles: Vec> = Vec::with_capacity(decoded.lanes.len()); + let mut claimed_inc_sum_lt = K::ZERO; + for lane in decoded.lanes.iter() { + let (oracle, claim) = TwistValEvalOracleSparseTime::new( + lane.wa_bits.clone(), + lane.has_write.clone(), + lane.inc_at_write_addr.clone(), + r_addr, + r_time, + ); + lt_oracles.push(Box::new(oracle)); + claimed_inc_sum_lt += claim; + } + let oracle_lt: Box = Box::new(SumRoundOracle::new(lt_oracles)); + + let mut total_oracles: Vec> = Vec::with_capacity(decoded.lanes.len()); + let mut claimed_inc_sum_total = K::ZERO; + for lane in decoded.lanes.iter() { + let (oracle, claim) = TwistTotalIncOracleSparseTime::new( + lane.wa_bits.clone(), + lane.has_write.clone(), + lane.inc_at_write_addr.clone(), + r_addr, + ); + total_oracles.push(Box::new(oracle)); + claimed_inc_sum_total += claim; + } + let oracle_total: Box = Box::new(SumRoundOracle::new(total_oracles)); + + val_oracles.push(oracle_lt); + bind_claims.push((plan.bind_tags[claim_idx], claimed_inc_sum_lt)); + claimed_sums.push(claimed_inc_sum_lt); + claim_idx += 1; + + val_oracles.push(oracle_total); + bind_claims.push((plan.bind_tags[claim_idx], claimed_inc_sum_total)); + claimed_sums.push(claimed_inc_sum_total); + claim_idx += 1; + + claimed_inc_sums_lt.push(claimed_inc_sum_lt); + claimed_inc_sums_total.push(claimed_inc_sum_total); + + if let Some(prev) = prev_step { + let (prev_inst, _prev_wit) = prev + .mem_instances + .get(i_mem) + .ok_or_else(|| PiCcsError::ProtocolError("missing prev mem instance".into()))?; + if prev_inst.d != mem_inst.d + || prev_inst.ell != mem_inst.ell + || prev_inst.k != mem_inst.k + || prev_inst.lanes != mem_inst.lanes + { + return 
Err(PiCcsError::InvalidInput(format!( + "Twist rollover requires stable geometry at mem_idx={}: prev (k={}, d={}, ell={}, lanes={}) vs cur (k={}, d={}, ell={}, lanes={})", + i_mem, + prev_inst.k, + prev_inst.d, + prev_inst.ell, + prev_inst.lanes, + mem_inst.k, + mem_inst.d, + mem_inst.ell, + mem_inst.lanes + ))); + } + let prev_decoded = prev_twist_decoded + .ok_or_else(|| PiCcsError::ProtocolError("missing prev Twist decoded cols".into()))? + .get(i_mem) + .ok_or_else(|| PiCcsError::ProtocolError("missing prev Twist decoded cols at mem_idx".into()))?; + if prev_decoded.lanes.is_empty() { + return Err(PiCcsError::ProtocolError( + "missing prev Twist decoded cols lanes".into(), + )); + } + + let mut prev_total_oracles: Vec> = Vec::with_capacity(prev_decoded.lanes.len()); + let mut claimed_prev_total = K::ZERO; + for lane in prev_decoded.lanes.iter() { + let (oracle, claim) = TwistTotalIncOracleSparseTime::new( + lane.wa_bits.clone(), + lane.has_write.clone(), + lane.inc_at_write_addr.clone(), + r_addr, + ); + prev_total_oracles.push(Box::new(oracle)); + claimed_prev_total += claim; + } + let oracle_prev_total: Box = Box::new(SumRoundOracle::new(prev_total_oracles)); + + val_oracles.push(oracle_prev_total); + bind_claims.push((plan.bind_tags[claim_idx], claimed_prev_total)); + claimed_sums.push(claimed_prev_total); + claim_idx += 1; + + claimed_prev_inc_sums_total.push(Some(claimed_prev_total)); + } else { + claimed_prev_inc_sums_total.push(None); + } + } + + tr.append_message( + b"twist/val_eval/batch_start", + &(step.mem_instances.len() as u64).to_le_bytes(), + ); + tr.append_message(b"twist/val_eval/step_idx", &(step_idx as u64).to_le_bytes()); + bind_twist_val_eval_claim_sums(tr, &bind_claims); + + let mut claims: Vec> = val_oracles + .iter_mut() + .zip(claimed_sums.iter()) + .zip(plan.labels.iter()) + .map(|((oracle, sum), label)| BatchedClaim { + oracle: oracle.as_mut(), + claimed_sum: *sum, + label: *label, + }) + .collect(); + + let (r_val_out, 
per_claim_results) = + run_batched_sumcheck_prover_ds(tr, b"twist/val_eval_batch", step_idx, claims.as_mut_slice())?; + + if per_claim_results.len() != claim_count { + return Err(PiCcsError::ProtocolError(format!( + "twist val-eval results count mismatch (expected {}, got {})", + claim_count, + per_claim_results.len() + ))); + } + if r_val_out.len() != r_time.len() { + return Err(PiCcsError::ProtocolError(format!( + "twist val-eval r_val.len()={}, expected ell_n={}", + r_val_out.len(), + r_time.len() + ))); + } + r_val = r_val_out; + + for i in 0..n_mem { + let base = claims_per_mem * i; + twist_val_eval_proofs.push(twist::TwistValEvalProof { + claimed_inc_sum_lt: claimed_inc_sums_lt[i], + rounds_lt: per_claim_results[base].round_polys.clone(), + claimed_inc_sum_total: claimed_inc_sums_total[i], + rounds_total: per_claim_results[base + 1].round_polys.clone(), + claimed_prev_inc_sum_total: claimed_prev_inc_sums_total[i], + rounds_prev_total: has_prev.then(|| per_claim_results[base + 2].round_polys.clone()), + }); + } + + tr.append_message(b"twist/val_eval/batch_done", &[]); + } + + if step.lut_instances.is_empty() { + if !shout_addr_pre.claimed_sums.is_empty() || !shout_addr_pre.groups.is_empty() { + return Err(PiCcsError::ProtocolError( + "shout_addr_pre must be empty when there are no Shout instances".into(), + )); + } + } + + for _ in 0..step.lut_instances.len() { + proofs.push(MemOrLutProof::Shout(ShoutProofK::default())); + } + + for idx in 0..step.mem_instances.len() { + let mut proof = TwistProofK::default(); + proof.addr_pre = twist_pre + .get(idx) + .ok_or_else(|| PiCcsError::ProtocolError("missing Twist addr_pre".into()))? 
+ .addr_pre + .clone(); + proof.val_eval = twist_val_eval_proofs.get(idx).cloned(); + + proofs.push(MemOrLutProof::Twist(proof)); + } + + if !step.mem_instances.is_empty() { + if r_val.len() != r_time.len() { + return Err(PiCcsError::ProtocolError(format!( + "twist val-eval r_val.len()={}, expected ell_n={}", + r_val.len(), + r_time.len() + ))); + } + + let core_t = s.t(); + + // Shared-bus mode: val-lane checks read bus openings from CPU ME claims at r_val. + // Emit CPU ME at r_val for current step (and previous step for rollover). + let (mcs_inst, mcs_wit) = &step.mcs; + let cpu_claims_cur = ts::emit_me_claims_for_mats( + tr, + b"cpu_bus/me_digest_val", + params, + s, + core::slice::from_ref(&mcs_inst.c), + core::slice::from_ref(&mcs_wit.Z), + &r_val, + m_in, + )?; + if cpu_claims_cur.len() != 1 { + return Err(PiCcsError::ProtocolError(format!( + "expected exactly 1 CPU ME claim at r_val, got {}", + cpu_claims_cur.len() + ))); + } + let mut cpu_claims_cur = cpu_claims_cur; + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( + params, + cpu_bus, + core_t, + &mcs_wit.Z, + &mut cpu_claims_cur[0], + )?; + val_me_claims.extend(cpu_claims_cur); + + if let Some(prev) = prev_step { + let (prev_mcs_inst, prev_mcs_wit) = &prev.mcs; + let cpu_claims_prev = ts::emit_me_claims_for_mats( + tr, + b"cpu_bus/me_digest_val", + params, + s, + core::slice::from_ref(&prev_mcs_inst.c), + core::slice::from_ref(&prev_mcs_wit.Z), + &r_val, + m_in, + )?; + if cpu_claims_prev.len() != 1 { + return Err(PiCcsError::ProtocolError(format!( + "expected exactly 1 prev CPU ME claim at r_val, got {}", + cpu_claims_prev.len() + ))); + } + let mut cpu_claims_prev = cpu_claims_prev; + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( + params, + cpu_bus, + core_t, + &prev_mcs_wit.Z, + &mut cpu_claims_prev[0], + )?; + val_me_claims.extend(cpu_claims_prev); + } + } + + if step.mem_instances.is_empty() { + if !twist_val_eval_proofs.is_empty() { + return 
Err(PiCcsError::ProtocolError( + "twist val-eval proofs must be empty when no mem instances are present".into(), + )); + } + if !r_val.is_empty() { + return Err(PiCcsError::ProtocolError( + "twist r_val must be empty when no mem instances are present".into(), + )); + } + if !val_me_claims.is_empty() { + return Err(PiCcsError::ProtocolError( + "twist val-lane ME claims must be empty when no mem instances are present".into(), + )); + } + } else if val_me_claims.is_empty() { + return Err(PiCcsError::ProtocolError( + "twist val-eval requires non-empty val-lane ME claims".into(), + )); + } + + let (wb_claims, wp_claims) = emit_route_a_wb_wp_me_claims(tr, params, s, step, r_time)?; + wb_me_claims.extend(wb_claims); + wp_me_claims.extend(wp_claims); + + Ok(MemSidecarProof { + val_me_claims, + wb_me_claims, + wp_me_claims, + shout_addr_pre: shout_addr_pre.clone(), + proofs, + }) +} + +// ============================================================================ +// ============================================================================ diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_oracles.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_oracles.rs new file mode 100644 index 00000000..a2c2cdcc --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_oracles.rs @@ -0,0 +1,1472 @@ +use super::*; + +pub(crate) fn build_route_a_memory_oracles( + params: &NeoParams, + step: &StepWitnessBundle, + ell_n: usize, + r_cycle: &[K], + shout_pre: &ShoutAddrPreBatchProverData, + twist_pre: &[TwistAddrPreProverData], +) -> Result { + if ell_n != r_cycle.len() { + return Err(PiCcsError::InvalidInput(format!( + "Route A: ell_n mismatch (ell_n={ell_n}, r_cycle.len()={})", + r_cycle.len() + ))); + } + if shout_pre.decoded.len() != step.lut_instances.len() { + return Err(PiCcsError::InvalidInput(format!( + "shout pre-time count mismatch (expected {}, got {})", + step.lut_instances.len(), + shout_pre.decoded.len() + ))); + } + if twist_pre.len() != 
step.mem_instances.len() { + return Err(PiCcsError::InvalidInput(format!( + "twist pre-time decoded count mismatch (expected {}, got {})", + step.mem_instances.len(), + twist_pre.len() + ))); + } + + let (event_alpha, event_beta, event_gamma, shout_event_trace_hash) = + build_event_table_shout_context(params, step, ell_n, r_cycle)?; + + let mut shout_oracles = Vec::with_capacity(step.lut_instances.len()); + let shout_gamma_specs = + RouteATimeClaimPlan::derive_shout_gamma_groups_for_instances(step.lut_instances.iter().map(|(inst, _)| inst)); + let mut shout_lane_to_gamma: std::collections::HashMap<(usize, usize), usize> = std::collections::HashMap::new(); + for (g_idx, g) in shout_gamma_specs.iter().enumerate() { + for lane in g.lanes.iter() { + shout_lane_to_gamma.insert((lane.inst_idx, lane.lane_idx), g_idx); + } + } + let mut r_addr_by_ell: std::collections::BTreeMap = std::collections::BTreeMap::new(); + for g in shout_pre.addr_pre.groups.iter() { + r_addr_by_ell.insert(g.ell_addr, g.r_addr.as_slice()); + } + for (lut_idx, ((lut_inst, _lut_wit), decoded)) in step + .lut_instances + .iter() + .zip(shout_pre.decoded.iter()) + .enumerate() + { + let ell_addr = lut_inst.d * lut_inst.ell; + let ell_addr_u32 = u32::try_from(ell_addr) + .map_err(|_| PiCcsError::InvalidInput("Shout(Route A): ell_addr overflows u32".into()))?; + let r_addr = *r_addr_by_ell + .get(&ell_addr_u32) + .ok_or_else(|| PiCcsError::ProtocolError("missing shout addr-pre group r_addr".into()))?; + if r_addr.len() != ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "Shout(Route A): r_addr.len()={} != ell_addr={}", + r_addr.len(), + ell_addr + ))); + } + + if decoded.lanes.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "Shout(Route A): decoded lanes empty at lut_idx={lut_idx}" + ))); + } + + let lane_count = decoded.lanes.len(); + let mut lanes: Vec = Vec::with_capacity(lane_count); + + let packed_layout = rv32_packed_shout_layout(&lut_inst.table_spec)?; + let packed_op = 
packed_layout.map(|(op, _time_bits)| op); + let packed_time_bits = packed_layout.map(|(_op, time_bits)| time_bits).unwrap_or(0); + let is_packed = packed_op.is_some(); + if packed_time_bits != 0 && packed_time_bits != ell_n { + return Err(PiCcsError::InvalidInput(format!( + "event-table Shout expects time_bits == ell_n (time_bits={packed_time_bits}, ell_n={ell_n})" + ))); + } + + for (lane_idx, lane) in decoded.lanes.iter().enumerate() { + let gamma_group = shout_lane_to_gamma.get(&(lut_idx, lane_idx)).copied(); + if let Some(op) = packed_op { + let time_bits = packed_time_bits; + let packed_cols: &[SparseIdxVec] = lane.addr_bits.get(time_bits..).ok_or_else(|| { + PiCcsError::InvalidInput("packed RV32: addr_bits too short for time_bits prefix".into()) + })?; + let lhs = packed_cols + .get(0) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32: missing lhs column".into()))? + .clone(); + let rhs = packed_cols + .get(1) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32: missing rhs column".into()))? + .clone(); + + // Packed bitwise (AND/OR/XOR): base-4 digit decomposition. 
+ let (bitwise_lhs_digits, bitwise_rhs_digits) = match op { + Rv32PackedShoutOp::And + | Rv32PackedShoutOp::Andn + | Rv32PackedShoutOp::Or + | Rv32PackedShoutOp::Xor => { + if packed_cols.len() != 34 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 bitwise: expected ell_addr=34, got {}", + packed_cols.len() + ))); + } + let lhs_digits: Vec> = packed_cols.iter().skip(2).take(16).cloned().collect(); + let rhs_digits: Vec> = packed_cols.iter().skip(18).take(16).cloned().collect(); + if lhs_digits.len() != 16 || rhs_digits.len() != 16 { + return Err(PiCcsError::ProtocolError( + "packed RV32 bitwise: digit slice length mismatch".into(), + )); + } + (lhs_digits, rhs_digits) + } + _ => (Vec::new(), Vec::new()), + }; + + let value_oracle: Box = match op { + Rv32PackedShoutOp::And => Box::new(Rv32PackedAndOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + bitwise_lhs_digits.clone(), + bitwise_rhs_digits.clone(), + lane.val.clone(), + )), + Rv32PackedShoutOp::Andn => Box::new(Rv32PackedAndnOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + bitwise_lhs_digits.clone(), + bitwise_rhs_digits.clone(), + lane.val.clone(), + )), + Rv32PackedShoutOp::Add => Box::new(Rv32PackedAddOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 ADD: missing carry column".into()))? + .clone(), + lane.val.clone(), + )), + Rv32PackedShoutOp::Or => Box::new(Rv32PackedOrOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + bitwise_lhs_digits.clone(), + bitwise_rhs_digits.clone(), + lane.val.clone(), + )), + Rv32PackedShoutOp::Sub => Box::new(Rv32PackedSubOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SUB: missing borrow column".into()))? 
+ .clone(), + lane.val.clone(), + )), + Rv32PackedShoutOp::Xor => Box::new(Rv32PackedXorOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + bitwise_lhs_digits.clone(), + bitwise_rhs_digits.clone(), + lane.val.clone(), + )), + Rv32PackedShoutOp::Eq => Box::new(Rv32PackedEqOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + { + let diff_bits: Vec> = packed_cols.iter().skip(3).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 EQ: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + diff_bits + }, + lane.val.clone(), + )), + Rv32PackedShoutOp::Neq => Box::new(Rv32PackedNeqOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + { + let diff_bits: Vec> = packed_cols.iter().skip(3).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 NEQ: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + diff_bits + }, + lane.val.clone(), + )), + Rv32PackedShoutOp::Mul => { + let carry_bits: Vec> = packed_cols.iter().skip(2).cloned().collect(); + if carry_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MUL: expected 32 carry bits, got {}", + carry_bits.len() + ))); + } + Box::new(Rv32PackedMulOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + carry_bits, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Mulhu => { + let lo_bits: Vec> = packed_cols.iter().skip(2).cloned().collect(); + if lo_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MULHU: expected 32 lo bits, got {}", + lo_bits.len() + ))); + } + Box::new(Rv32PackedMulhuOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + lo_bits, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Mulh => { + let hi = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULH: missing hi opening".into()))? 
+ .clone(); + let lo_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if lo_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MULH: expected 32 lo bits, got {}", + lo_bits.len() + ))); + } + Box::new(Rv32PackedMulHiOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + lo_bits, + hi, + )) + } + Rv32PackedShoutOp::Mulhsu => { + let hi = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULHSU: missing hi opening".into()))? + .clone(); + let lo_bits: Vec> = packed_cols.iter().skip(5).cloned().collect(); + if lo_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MULHSU: expected 32 lo bits, got {}", + lo_bits.len() + ))); + } + Box::new(Rv32PackedMulHiOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + lo_bits, + hi, + )) + } + Rv32PackedShoutOp::Slt => { + let lhs_sign = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SLT: missing lhs_sign bit".into()))? + .clone(); + let rhs_sign = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SLT: missing rhs_sign bit".into()))? + .clone(); + let diff = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SLT: missing diff opening".into()))? + .clone(); + Box::new(Rv32PackedSltOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + lhs_sign, + rhs_sign, + diff, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Divu => { + let rem = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIVU: missing rem opening".into()))? + .clone(); + let rhs_is_zero = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIVU: missing rhs_is_zero".into()))? 
+ .clone(); + Box::new(Rv32PackedDivuOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + rem, + rhs_is_zero, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Remu => { + let quot = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REMU: missing quot opening".into()))? + .clone(); + let rhs_is_zero = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REMU: missing rhs_is_zero".into()))? + .clone(); + Box::new(Rv32PackedRemuOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + quot, + rhs_is_zero, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Div => { + let rhs_is_zero = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing rhs_is_zero".into()))? + .clone(); + let lhs_sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing lhs_sign".into()))? + .clone(); + let rhs_sign = packed_cols + .get(7) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing rhs_sign".into()))? + .clone(); + let q_abs = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing q_abs".into()))? + .clone(); + let q_is_zero = packed_cols + .get(9) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing q_is_zero".into()))? + .clone(); + Box::new(Rv32PackedDivOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs_sign, + rhs_sign, + rhs_is_zero, + q_abs, + q_is_zero, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Rem => { + let rhs_is_zero = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing rhs_is_zero".into()))? + .clone(); + let lhs_sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing lhs_sign".into()))? 
+ .clone(); + let r_abs = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing r_abs".into()))? + .clone(); + let r_is_zero = packed_cols + .get(9) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing r_is_zero".into()))? + .clone(); + Box::new(Rv32PackedRemOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + lhs_sign, + rhs_is_zero, + r_abs, + r_is_zero, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Sll => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLL: expected 5 shamt bits, got {}", + shamt_bits.len() + ))); + } + let carry_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if carry_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLL: expected 32 carry bits, got {}", + carry_bits.len() + ))); + } + Box::new(Rv32PackedSllOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + shamt_bits, + carry_bits, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Srl => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRL: expected 5 shamt bits, got {}", + shamt_bits.len() + ))); + } + let rem_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if rem_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRL: expected 32 rem bits, got {}", + rem_bits.len() + ))); + } + Box::new(Rv32PackedSrlOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + shamt_bits, + rem_bits, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Sra => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRA: expected 5 shamt bits, 
got {}", + shamt_bits.len() + ))); + } + let sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SRA: missing sign bit".into()))? + .clone(); + let rem_bits: Vec> = packed_cols.iter().skip(7).cloned().collect(); + if rem_bits.len() != 31 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRA: expected 31 rem bits, got {}", + rem_bits.len() + ))); + } + Box::new(Rv32PackedSraOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + shamt_bits, + sign, + rem_bits, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Sltu => Box::new(Rv32PackedSltuOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs.clone(), + rhs.clone(), + packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SLTU: missing diff opening".into()))? + .clone(), + lane.val.clone(), + )), + }; + let adapter_oracle: Box = match op { + Rv32PackedShoutOp::And + | Rv32PackedShoutOp::Andn + | Rv32PackedShoutOp::Or + | Rv32PackedShoutOp::Xor => { + let weights = bitness_weights(r_cycle, 34, 0x4249_5457_4F50u64 + lut_idx as u64); + Box::new(Rv32PackedBitwiseAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs, + rhs, + bitwise_lhs_digits, + bitwise_rhs_digits, + weights, + )) + } + Rv32PackedShoutOp::Add + | Rv32PackedShoutOp::Sub + | Rv32PackedShoutOp::Sll + | Rv32PackedShoutOp::Mul + | Rv32PackedShoutOp::Mulhu => Box::new(ZeroOracleSparseTime::new(r_cycle.len(), 2)), + Rv32PackedShoutOp::Mulh => { + let hi = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULH: missing hi opening".into()))? + .clone(); + let lhs_sign = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULH: missing lhs_sign".into()))? + .clone(); + let rhs_sign = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULH: missing rhs_sign".into()))? 
+ .clone(); + let k = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULH: missing k opening".into()))? + .clone(); + let weights = bitness_weights(r_cycle, 2, 0x4D55_4C48_4144_5054u64 + lut_idx as u64); + let w = [weights[0], weights[1]]; + Box::new(Rv32PackedMulhAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs, + rhs, + lhs_sign, + rhs_sign, + hi, + k, + lane.val.clone(), + w, + )) + } + Rv32PackedShoutOp::Mulhsu => { + let hi = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULHSU: missing hi opening".into()))? + .clone(); + let lhs_sign = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULHSU: missing lhs_sign".into()))? + .clone(); + let borrow = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULHSU: missing borrow".into()))? + .clone(); + Box::new(Rv32PackedMulhsuAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs, + rhs, + lhs_sign, + hi, + borrow, + lane.val.clone(), + )) + } + Rv32PackedShoutOp::Divu => { + let rem = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIVU: missing rem opening".into()))? + .clone(); + let rhs_is_zero = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIVU: missing rhs_is_zero".into()))? + .clone(); + let diff = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIVU: missing diff".into()))? 
+ .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 DIVU: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + let weights = bitness_weights(r_cycle, 4, 0x4449_5655_4144_5054u64 + lut_idx as u64); + let w = [weights[0], weights[1], weights[2], weights[3]]; + Box::new(Rv32PackedDivRemuAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + rhs, + rhs_is_zero, + rem, + diff, + diff_bits, + w, + )) + } + Rv32PackedShoutOp::Remu => { + let rhs_is_zero = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REMU: missing rhs_is_zero".into()))? + .clone(); + let diff = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REMU: missing diff".into()))? + .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 REMU: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + let weights = bitness_weights(r_cycle, 4, 0x4449_5655_4144_5054u64 + lut_idx as u64); + let w = [weights[0], weights[1], weights[2], weights[3]]; + Box::new(Rv32PackedDivRemuAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + rhs, + rhs_is_zero, + lane.val.clone(), + diff, + diff_bits, + w, + )) + } + Rv32PackedShoutOp::Div => { + let rhs_is_zero = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing rhs_is_zero".into()))? + .clone(); + let lhs_sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing lhs_sign".into()))? + .clone(); + let rhs_sign = packed_cols + .get(7) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing rhs_sign".into()))? + .clone(); + let q_abs = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing q_abs".into()))? 
+ .clone(); + let r_abs = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing r_abs".into()))? + .clone(); + let q_is_zero = packed_cols + .get(9) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing q_is_zero".into()))? + .clone(); + let diff = packed_cols + .get(10) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing diff".into()))? + .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(11).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 DIV: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + let weights = bitness_weights(r_cycle, 7, 0x4449_565F_4144_5054u64 + lut_idx as u64); + let w = [ + weights[0], weights[1], weights[2], weights[3], weights[4], weights[5], weights[6], + ]; + Box::new(Rv32PackedDivRemAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs, + rhs, + rhs_is_zero, + lhs_sign, + rhs_sign, + q_abs.clone(), + r_abs, + q_abs, + q_is_zero, + diff, + diff_bits, + w, + )) + } + Rv32PackedShoutOp::Rem => { + let rhs_is_zero = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing rhs_is_zero".into()))? + .clone(); + let lhs_sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing lhs_sign".into()))? + .clone(); + let rhs_sign = packed_cols + .get(7) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing rhs_sign".into()))? + .clone(); + let q_abs = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing q_abs".into()))? + .clone(); + let r_abs = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing r_abs".into()))? + .clone(); + let r_is_zero = packed_cols + .get(9) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing r_is_zero".into()))? 
+ .clone(); + let diff = packed_cols + .get(10) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing diff".into()))? + .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(11).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 REM: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + let weights = bitness_weights(r_cycle, 7, 0x4449_565F_4144_5054u64 + lut_idx as u64); + let w = [ + weights[0], weights[1], weights[2], weights[3], weights[4], weights[5], weights[6], + ]; + Box::new(Rv32PackedDivRemAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs, + rhs, + rhs_is_zero, + lhs_sign, + rhs_sign, + q_abs, + r_abs.clone(), + r_abs, + r_is_zero, + diff, + diff_bits, + w, + )) + } + Rv32PackedShoutOp::Slt => { + let diff_bits: Vec> = packed_cols.iter().skip(5).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLT: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + Box::new(U32DecompOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + packed_cols + .get(2) + .ok_or_else(|| { + PiCcsError::InvalidInput("packed RV32 SLT: missing diff opening".into()) + })? 
+ .clone(), + diff_bits, + )) + } + Rv32PackedShoutOp::Srl => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRL: expected 5 shamt bits, got {}", + shamt_bits.len() + ))); + } + let rem_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if rem_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRL: expected 32 rem bits, got {}", + rem_bits.len() + ))); + } + Box::new(Rv32PackedSrlAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + shamt_bits, + rem_bits, + )) + } + Rv32PackedShoutOp::Sra => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRA: expected 5 shamt bits, got {}", + shamt_bits.len() + ))); + } + let rem_bits: Vec> = packed_cols.iter().skip(7).cloned().collect(); + if rem_bits.len() != 31 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRA: expected 31 rem bits, got {}", + rem_bits.len() + ))); + } + Box::new(Rv32PackedSraAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + shamt_bits, + rem_bits, + )) + } + Rv32PackedShoutOp::Eq => Box::new(Rv32PackedEqAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs, + rhs, + packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 EQ: missing borrow bit".into()))? 
+ .clone(), + { + let diff_bits: Vec> = packed_cols.iter().skip(3).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 EQ: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + diff_bits + }, + )), + Rv32PackedShoutOp::Neq => Box::new(Rv32PackedNeqAdapterOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + lhs, + rhs, + packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 NEQ: missing borrow bit".into()))? + .clone(), + { + let diff_bits: Vec> = packed_cols.iter().skip(3).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 NEQ: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + diff_bits + }, + )), + Rv32PackedShoutOp::Sltu => { + let diff_bits: Vec> = packed_cols.iter().skip(3).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLTU: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + Box::new(U32DecompOracleSparseTime::new( + r_cycle, + lane.has_lookup.clone(), + packed_cols + .get(2) + .ok_or_else(|| { + PiCcsError::InvalidInput("packed RV32 SLTU: missing diff opening".into()) + })? + .clone(), + diff_bits, + )) + } + }; + + let (event_table_hash, event_table_hash_claim) = if time_bits > 0 { + let time_bits_cols: Vec> = lane.addr_bits.iter().take(time_bits).cloned().collect(); + + let lhs_col = packed_cols + .get(0) + .ok_or_else(|| PiCcsError::InvalidInput("event-table hash: missing lhs".into()))? + .clone(); + + let rhs_terms: Vec<(SparseIdxVec, K)> = match op { + Rv32PackedShoutOp::Sll | Rv32PackedShoutOp::Srl | Rv32PackedShoutOp::Sra => { + let mut out: Vec<(SparseIdxVec, K)> = Vec::with_capacity(5); + for i in 0..5usize { + let b = packed_cols + .get(1 + i) + .ok_or_else(|| { + PiCcsError::InvalidInput("event-table hash: missing shamt bit".into()) + })? 
+ .clone(); + out.push((b, K::from(F::from_u64(1u64 << i)))); + } + out + } + _ => vec![( + packed_cols + .get(1) + .ok_or_else(|| PiCcsError::InvalidInput("event-table hash: missing rhs".into()))? + .clone(), + K::ONE, + )], + }; + + let (oracle, claim) = ShoutEventTableHashOracleSparseTime::new( + &r_cycle[..time_bits], + time_bits_cols, + lane.has_lookup.clone(), + lane.val.clone(), + lhs_col, + rhs_terms, + event_alpha, + event_beta, + event_gamma, + ); + (Some(Box::new(oracle) as Box), Some(claim)) + } else { + (None, None) + }; + + lanes.push(RouteAShoutTimeLaneOracles { + value: value_oracle, + // Enforce correctness: claim must be 0. + value_claim: K::ZERO, + adapter: adapter_oracle, + adapter_claim: K::ZERO, + event_table_hash, + event_table_hash_claim, + gamma_group: None, + }); + } else { + let (value_oracle, value_claim) = + ShoutValueOracleSparse::new(r_cycle, lane.has_lookup.clone(), lane.val.clone()); + + let (adapter_oracle, adapter_claim) = IndexAdapterOracleSparseTime::new_with_gate( + r_cycle, + lane.has_lookup.clone(), + lane.addr_bits.clone(), + r_addr, + ); + + lanes.push(RouteAShoutTimeLaneOracles { + value: Box::new(value_oracle), + value_claim, + adapter: Box::new(adapter_oracle), + adapter_claim, + event_table_hash: None, + event_table_hash_claim: None, + gamma_group, + }); + } + } + + let bitness: Vec> = if is_packed { + // Packed RV32: boolean columns depend on the packed op. + let mut bit_cols: Vec> = Vec::new(); + for lane in decoded.lanes.iter() { + // Event-table packed: time bits must be boolean. + if packed_time_bits > 0 { + bit_cols.extend(lane.addr_bits.iter().take(packed_time_bits).cloned()); + } + let packed_cols: &[SparseIdxVec] = lane + .addr_bits + .get(packed_time_bits..) 
+ .ok_or_else(|| PiCcsError::InvalidInput("packed RV32: missing packed cols".into()))?; + match packed_op { + Some( + Rv32PackedShoutOp::And + | Rv32PackedShoutOp::Andn + | Rv32PackedShoutOp::Or + | Rv32PackedShoutOp::Xor, + ) => { + bit_cols.push(lane.has_lookup.clone()); + } + Some(Rv32PackedShoutOp::Add | Rv32PackedShoutOp::Sub) => { + let aux = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32: missing aux column".into()))? + .clone(); + bit_cols.push(aux); + bit_cols.push(lane.has_lookup.clone()); + } + Some(Rv32PackedShoutOp::Eq | Rv32PackedShoutOp::Neq) => { + let borrow = packed_cols + .get(2) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 EQ/NEQ: missing borrow bit".into()))? + .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(3).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 EQ/NEQ: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.push(lane.val.clone()); + bit_cols.push(borrow); + bit_cols.extend(diff_bits); + } + Some(Rv32PackedShoutOp::Mul) => { + let carry_bits: Vec> = packed_cols.iter().skip(2).cloned().collect(); + if carry_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MUL: expected 32 carry bits, got {}", + carry_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.extend(carry_bits); + } + Some(Rv32PackedShoutOp::Mulhu) => { + let lo_bits: Vec> = packed_cols.iter().skip(2).cloned().collect(); + if lo_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MULHU: expected 32 lo bits, got {}", + lo_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.extend(lo_bits); + } + Some(Rv32PackedShoutOp::Mulh) => { + let lhs_sign = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULH: missing lhs_sign bit".into()))? 
+ .clone(); + let rhs_sign = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULH: missing rhs_sign bit".into()))? + .clone(); + let lo_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if lo_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MULH: expected 32 lo bits, got {}", + lo_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.push(lhs_sign); + bit_cols.push(rhs_sign); + bit_cols.extend(lo_bits); + } + Some(Rv32PackedShoutOp::Mulhsu) => { + let lhs_sign = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULHSU: missing lhs_sign bit".into()))? + .clone(); + let borrow = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 MULHSU: missing borrow bit".into()))? + .clone(); + let lo_bits: Vec> = packed_cols.iter().skip(5).cloned().collect(); + if lo_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 MULHSU: expected 32 lo bits, got {}", + lo_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.push(lhs_sign); + bit_cols.push(borrow); + bit_cols.extend(lo_bits); + } + Some(Rv32PackedShoutOp::Slt) => { + let lhs_sign = packed_cols + .get(3) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SLT: missing lhs_sign bit".into()))? + .clone(); + let rhs_sign = packed_cols + .get(4) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SLT: missing rhs_sign bit".into()))? 
+ .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(5).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLT: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + bit_cols.push(lane.val.clone()); + bit_cols.push(lane.has_lookup.clone()); + bit_cols.push(lhs_sign); + bit_cols.push(rhs_sign); + bit_cols.extend(diff_bits); + } + Some(Rv32PackedShoutOp::Sll) => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLL: expected 5 shamt bits, got {}", + shamt_bits.len() + ))); + } + let carry_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if carry_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLL: expected 32 carry bits, got {}", + carry_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.extend(shamt_bits); + bit_cols.extend(carry_bits); + } + Some(Rv32PackedShoutOp::Srl) => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRL: expected 5 shamt bits, got {}", + shamt_bits.len() + ))); + } + let rem_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if rem_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRL: expected 32 rem bits, got {}", + rem_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.extend(shamt_bits); + bit_cols.extend(rem_bits); + } + Some(Rv32PackedShoutOp::Sra) => { + let shamt_bits: Vec> = packed_cols.iter().skip(1).take(5).cloned().collect(); + if shamt_bits.len() != 5 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRA: expected 5 shamt bits, got {}", + shamt_bits.len() + ))); + } + let sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 SRA: missing 
sign bit".into()))? + .clone(); + let rem_bits: Vec> = packed_cols.iter().skip(7).cloned().collect(); + if rem_bits.len() != 31 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SRA: expected 31 rem bits, got {}", + rem_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.extend(shamt_bits); + bit_cols.push(sign); + bit_cols.extend(rem_bits); + } + Some(Rv32PackedShoutOp::Sltu) => { + let diff_bits: Vec> = packed_cols.iter().skip(3).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 SLTU: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + bit_cols.push(lane.val.clone()); + bit_cols.push(lane.has_lookup.clone()); + bit_cols.extend(diff_bits); + } + Some(Rv32PackedShoutOp::Divu | Rv32PackedShoutOp::Remu) => { + let rhs_is_zero = packed_cols + .get(4) + .ok_or_else(|| { + PiCcsError::InvalidInput("packed RV32 DIVU/REMU: missing rhs_is_zero".into()) + })? + .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(6).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 DIVU/REMU: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.push(rhs_is_zero); + bit_cols.extend(diff_bits); + } + Some(Rv32PackedShoutOp::Div) => { + let rhs_is_zero = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing rhs_is_zero".into()))? + .clone(); + let lhs_sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing lhs_sign".into()))? + .clone(); + let rhs_sign = packed_cols + .get(7) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing rhs_sign".into()))? + .clone(); + let q_is_zero = packed_cols + .get(9) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 DIV: missing q_is_zero".into()))? 
+ .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(11).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 DIV: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.push(rhs_is_zero); + bit_cols.push(lhs_sign); + bit_cols.push(rhs_sign); + bit_cols.push(q_is_zero); + bit_cols.extend(diff_bits); + } + Some(Rv32PackedShoutOp::Rem) => { + let rhs_is_zero = packed_cols + .get(5) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing rhs_is_zero".into()))? + .clone(); + let lhs_sign = packed_cols + .get(6) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing lhs_sign".into()))? + .clone(); + let rhs_sign = packed_cols + .get(7) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing rhs_sign".into()))? + .clone(); + let r_is_zero = packed_cols + .get(9) + .ok_or_else(|| PiCcsError::InvalidInput("packed RV32 REM: missing r_is_zero".into()))? 
+ .clone(); + let diff_bits: Vec> = packed_cols.iter().skip(11).cloned().collect(); + if diff_bits.len() != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RV32 REM: expected 32 diff bits, got {}", + diff_bits.len() + ))); + } + bit_cols.push(lane.has_lookup.clone()); + bit_cols.push(rhs_is_zero); + bit_cols.push(lhs_sign); + bit_cols.push(rhs_sign); + bit_cols.push(r_is_zero); + bit_cols.extend(diff_bits); + } + None => { + return Err(PiCcsError::ProtocolError( + "packed_op drift: is_packed=true but packed_op=None".into(), + )); + } + } + } + let weights = bitness_weights(r_cycle, bit_cols.len(), 0x5348_4F55_54u64 + lut_idx as u64); + let bitness_oracle = LazyWeightedBitnessOracleSparseTime::new_with_cycle(r_cycle, bit_cols, weights); + vec![Box::new(bitness_oracle)] + } else { + let mut bit_cols: Vec> = Vec::with_capacity(lane_count * (ell_addr + 1)); + for lane in decoded.lanes.iter() { + bit_cols.extend(lane.addr_bits.iter().cloned()); + bit_cols.push(lane.has_lookup.clone()); + } + let weights = bitness_weights(r_cycle, bit_cols.len(), 0x5348_4F55_54u64 + lut_idx as u64); + let bitness_oracle = LazyWeightedBitnessOracleSparseTime::new_with_cycle(r_cycle, bit_cols, weights); + vec![Box::new(bitness_oracle)] + }; + + shout_oracles.push(RouteAShoutTimeOracles { lanes, bitness }); + } + + let mut shout_gamma_groups = Vec::with_capacity(shout_gamma_specs.len()); + for (g_idx, g) in shout_gamma_specs.iter().enumerate() { + let mut value_cols: Vec> = Vec::with_capacity(g.lanes.len() * 2); + let mut adapter_cols: Vec> = Vec::with_capacity(g.lanes.len() * (1 + g.ell_addr)); + let weights = bitness_weights(r_cycle, g.lanes.len(), 0x5348_5F47_414D_4Du64 ^ g.key); + let mut weighted_table: Vec = Vec::with_capacity(g.lanes.len()); + let mut group_r_addr: Option> = None; + let mut value_claim = K::ZERO; + let mut adapter_claim = K::ZERO; + + for (slot, lane_ref) in g.lanes.iter().enumerate() { + let (lut_inst, _lut_wit) = step + .lut_instances + 
.get(lane_ref.inst_idx) + .ok_or_else(|| PiCcsError::ProtocolError("shout gamma group inst idx drift".into()))?; + let decoded = shout_pre + .decoded + .get(lane_ref.inst_idx) + .ok_or_else(|| PiCcsError::ProtocolError("shout gamma decoded inst idx drift".into()))?; + let lane = decoded + .lanes + .get(lane_ref.lane_idx) + .ok_or_else(|| PiCcsError::ProtocolError("shout gamma decoded lane idx drift".into()))?; + let lane_oracles = shout_oracles + .get(lane_ref.inst_idx) + .and_then(|o| o.lanes.get(lane_ref.lane_idx)) + .ok_or_else(|| PiCcsError::ProtocolError("shout gamma lane oracle idx drift".into()))?; + if lane_oracles.gamma_group != Some(g_idx) { + return Err(PiCcsError::ProtocolError( + "shout gamma grouping mismatch between plan and oracle wiring".into(), + )); + } + let ell_addr = lut_inst.d * lut_inst.ell; + if ell_addr != g.ell_addr { + return Err(PiCcsError::ProtocolError("shout gamma group ell_addr mismatch".into())); + } + let ell_addr_u32 = u32::try_from(ell_addr) + .map_err(|_| PiCcsError::InvalidInput("shout gamma ell_addr overflows u32".into()))?; + let r_addr = *r_addr_by_ell + .get(&ell_addr_u32) + .ok_or_else(|| PiCcsError::ProtocolError("missing shout gamma group r_addr".into()))?; + if let Some(prev) = group_r_addr.as_ref() { + if prev.as_slice() != r_addr { + return Err(PiCcsError::ProtocolError( + "shout gamma group r_addr mismatch across lanes".into(), + )); + } + } else { + group_r_addr = Some(r_addr.to_vec()); + } + + let table_eval_at_r_addr = match &lut_inst.table_spec { + Some(spec) => spec.eval_table_mle(r_addr)?, + None => { + let pow2 = 1usize + .checked_shl(r_addr.len() as u32) + .ok_or_else(|| PiCcsError::InvalidInput("shout gamma 2^ell overflow".into()))?; + if lut_inst.table.len() < pow2 { + return Err(PiCcsError::InvalidInput(format!( + "shout gamma table too short: len={} < 2^ell={pow2}", + lut_inst.table.len() + ))); + } + let mut acc = K::ZERO; + for (i, &v) in lut_inst.table.iter().enumerate().take(pow2) { + let w = 
neo_memory::mle::chi_at_index(r_addr, i); + acc += K::from(v) * w; + } + acc + } + }; + + let w = weights[slot]; + value_claim += w * lane_oracles.value_claim; + adapter_claim += w * table_eval_at_r_addr * lane_oracles.adapter_claim; + weighted_table.push(w * table_eval_at_r_addr); + + value_cols.push(lane.has_lookup.clone()); + value_cols.push(lane.val.clone()); + + adapter_cols.push(lane.has_lookup.clone()); + adapter_cols.extend(lane.addr_bits.iter().cloned()); + } + + let value_weights = weights.clone(); + let value_oracle = FormulaOracleSparseTime::new( + value_cols, + 3, + r_cycle, + Box::new(move |vals: &[K]| { + let mut out = K::ZERO; + let mut idx = 0usize; + for w in value_weights.iter() { + let has = vals[idx]; + idx += 1; + let val = vals[idx]; + idx += 1; + out += *w * has * val; + } + debug_assert_eq!(idx, vals.len()); + out + }), + ); + + let adapter_coeffs = weighted_table.clone(); + let adapter_r_addr = group_r_addr.ok_or_else(|| PiCcsError::ProtocolError("empty shout gamma group".into()))?; + let ell_addr = g.ell_addr; + let adapter_oracle = FormulaOracleSparseTime::new( + adapter_cols, + 2 + ell_addr, + r_cycle, + Box::new(move |vals: &[K]| { + let mut out = K::ZERO; + let mut idx = 0usize; + for coeff in adapter_coeffs.iter() { + let has = vals[idx]; + idx += 1; + let mut eq = K::ONE; + for bit_idx in 0..ell_addr { + eq *= eq_bit_affine(vals[idx], adapter_r_addr[bit_idx]); + idx += 1; + } + out += *coeff * has * eq; + } + debug_assert_eq!(idx, vals.len()); + out + }), + ); + + shout_gamma_groups.push(RouteAShoutGammaGroupOracles { + value: Box::new(value_oracle), + value_claim, + adapter: Box::new(adapter_oracle), + adapter_claim, + }); + } + + let mut twist_oracles = Vec::with_capacity(step.mem_instances.len()); + for (mem_idx, ((mem_inst, _mem_wit), pre)) in step.mem_instances.iter().zip(twist_pre.iter()).enumerate() { + let init_at_r_addr = eval_init_at_r_addr(&mem_inst.init, mem_inst.k, &pre.addr_pre.r_addr)?; + let ell_addr = mem_inst.d * 
mem_inst.ell; + if pre.addr_pre.r_addr.len() != ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "Twist(Route A): r_addr.len()={} != ell_addr={}", + pre.addr_pre.r_addr.len(), + ell_addr + ))); + } + + if pre.decoded.lanes.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "Twist(Route A): decoded lanes empty at mem_idx={mem_idx}" + ))); + } + + let inc_terms_at_r_addr = build_twist_inc_terms_at_r_addr(&pre.decoded.lanes, &pre.addr_pre.r_addr); + + let mut read_oracles: Vec> = Vec::with_capacity(pre.decoded.lanes.len()); + let mut write_oracles: Vec> = Vec::with_capacity(pre.decoded.lanes.len()); + for lane in pre.decoded.lanes.iter() { + read_oracles.push(Box::new(TwistReadCheckOracleSparseTime::new_with_inc_terms( + r_cycle, + lane.has_read.clone(), + lane.rv.clone(), + lane.ra_bits.clone(), + &pre.addr_pre.r_addr, + init_at_r_addr, + inc_terms_at_r_addr.clone(), + ))); + write_oracles.push(Box::new(TwistWriteCheckOracleSparseTime::new_with_inc_terms( + r_cycle, + lane.has_write.clone(), + lane.wv.clone(), + lane.inc_at_write_addr.clone(), + lane.wa_bits.clone(), + &pre.addr_pre.r_addr, + init_at_r_addr, + inc_terms_at_r_addr.clone(), + ))); + } + let read_check: Box = Box::new(SumRoundOracle::new(read_oracles)); + let write_check: Box = Box::new(SumRoundOracle::new(write_oracles)); + + let lane_count = pre.decoded.lanes.len(); + let mut bit_cols: Vec> = Vec::with_capacity(lane_count * (2 * ell_addr + 2)); + for lane in pre.decoded.lanes.iter() { + bit_cols.extend(lane.ra_bits.iter().cloned()); + bit_cols.extend(lane.wa_bits.iter().cloned()); + bit_cols.push(lane.has_read.clone()); + bit_cols.push(lane.has_write.clone()); + } + let weights = bitness_weights(r_cycle, bit_cols.len(), 0x5457_4953_54u64 + mem_idx as u64); + let bitness_oracle = LazyWeightedBitnessOracleSparseTime::new_with_cycle(r_cycle, bit_cols, weights); + let bitness: Vec> = vec![Box::new(bitness_oracle)]; + + twist_oracles.push(RouteATwistTimeOracles { + read_check, + 
write_check, + bitness, + }); + } + + Ok(RouteAMemoryOracles { + shout: shout_oracles, + shout_gamma_groups, + shout_event_trace_hash, + twist: twist_oracles, + }) +} diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs new file mode 100644 index 00000000..b7051636 --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs @@ -0,0 +1,839 @@ +use super::*; + +pub(crate) fn verify_route_a_decode_terminals( + core_t: usize, + step: &StepInstanceBundle, + r_time: &[K], + r_cycle: &[K], + batched_final_values: &[K], + claim_plan: &RouteATimeClaimPlan, + mem_proof: &MemSidecarProof, +) -> Result<(), PiCcsError> { + if claim_plan.decode_fields.is_none() && claim_plan.decode_immediates.is_none() { + return Ok(()); + } + + if mem_proof.wb_me_claims.len() != 1 { + return Err(PiCcsError::ProtocolError( + "W2 requires WB ME openings for shared active/bit terminals".into(), + )); + } + + let decode_layout = Rv32DecodeSidecarLayout::new(); + let decode_open_cols = rv32_decode_lookup_backed_cols(&decode_layout); + if mem_proof.wp_me_claims.len() != 1 { + return Err(PiCcsError::ProtocolError( + "W2 requires WP ME openings for shared main-trace/decode terminals".into(), + )); + } + let wp_me = &mem_proof.wp_me_claims[0]; + if wp_me.r.as_slice() != r_time { + return Err(PiCcsError::ProtocolError( + "W2 WP ME claim r mismatch (expected r_time)".into(), + )); + } + if wp_me.c != step.mcs_inst.c { + return Err(PiCcsError::ProtocolError("W2 WP ME claim commitment mismatch".into())); + } + if wp_me.m_in != step.mcs_inst.m_in { + return Err(PiCcsError::ProtocolError("W2 WP ME claim m_in mismatch".into())); + } + let trace = Rv32TraceLayout::new(); + let wp_cols = rv32_trace_wp_opening_columns(&trace); + let control_extra_cols = if control_stage_required_for_step_instance(step) { + rv32_trace_control_extra_opening_columns(&trace) + } else { + Vec::new() + }; + 
let decode_open_start = core_t + .checked_add(wp_cols.len()) + .and_then(|v| v.checked_add(control_extra_cols.len())) + .ok_or_else(|| PiCcsError::InvalidInput("W2 decode opening start overflow".into()))?; + let decode_open_end = decode_open_start + .checked_add(decode_open_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("W2 decode opening end overflow".into()))?; + if wp_me.y_scalars.len() < decode_open_end { + return Err(PiCcsError::ProtocolError(format!( + "W2 decode openings missing on WP ME claim (got {}, need at least {decode_open_end})", + wp_me.y_scalars.len() + ))); + } + let decode_open = &wp_me.y_scalars[decode_open_start..decode_open_end]; + let decode_open_map: BTreeMap = decode_open_cols + .iter() + .copied() + .zip(decode_open.iter().copied()) + .collect(); + let decode_open_col = |col_id: usize| -> Result { + decode_open_map + .get(&col_id) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2(shared) missing decode opening col_id={col_id}"))) + }; + let wb_me = &mem_proof.wb_me_claims[0]; + let wb_cols = rv32_trace_wb_columns(&trace); + let need_wb = core_t + .checked_add(wb_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("W2 WB opening count overflow".into()))?; + if wb_me.y_scalars.len() != need_wb { + return Err(PiCcsError::ProtocolError(format!( + "W2 WB opening length mismatch (got {}, expected {need_wb})", + wb_me.y_scalars.len() + ))); + } + let wb_open = &wb_me.y_scalars[core_t..]; + let wb_open_col = |col_id: usize| -> Result { + let idx = wb_cols + .iter() + .position(|&c| c == col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2 missing WB opening column {col_id}")))?; + Ok(wb_open[idx]) + }; + + let wp_cols = rv32_trace_wp_opening_columns(&trace); + let need_wp = core_t + .checked_add(wp_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("W2 WP opening count overflow".into()))?; + if wp_me.y_scalars.len() < need_wp { + return Err(PiCcsError::ProtocolError(format!( + "W2 WP opening length mismatch 
(got {}, expected at least {need_wp})", + wp_me.y_scalars.len() + ))); + } + let wp_open = &wp_me.y_scalars[core_t..need_wp]; + let wp_open_col = |col_id: usize| -> Result { + let idx = wp_cols + .iter() + .position(|&c| c == col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W2 missing WP opening column {col_id}")))?; + Ok(wp_open[idx]) + }; + + if let Some(claim_idx) = claim_plan.decode_fields { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "w2/decode_fields claim index out of range".into(), + )); + } + let opcode_flags = [ + decode_open_col(decode_layout.op_lui)?, + decode_open_col(decode_layout.op_auipc)?, + decode_open_col(decode_layout.op_jal)?, + decode_open_col(decode_layout.op_jalr)?, + decode_open_col(decode_layout.op_branch)?, + decode_open_col(decode_layout.op_load)?, + decode_open_col(decode_layout.op_store)?, + decode_open_col(decode_layout.op_alu_imm)?, + decode_open_col(decode_layout.op_alu_reg)?, + decode_open_col(decode_layout.op_misc_mem)?, + decode_open_col(decode_layout.op_system)?, + decode_open_col(decode_layout.op_amo)?, + ]; + let funct3_is = [ + decode_open_col(decode_layout.funct3_is[0])?, + decode_open_col(decode_layout.funct3_is[1])?, + decode_open_col(decode_layout.funct3_is[2])?, + decode_open_col(decode_layout.funct3_is[3])?, + decode_open_col(decode_layout.funct3_is[4])?, + decode_open_col(decode_layout.funct3_is[5])?, + decode_open_col(decode_layout.funct3_is[6])?, + decode_open_col(decode_layout.funct3_is[7])?, + ]; + let funct3_bits = [ + decode_open_col(decode_layout.funct3_bit[0])?, + decode_open_col(decode_layout.funct3_bit[1])?, + decode_open_col(decode_layout.funct3_bit[2])?, + ]; + let funct7_bits = [ + decode_open_col(decode_layout.funct7_bit[0])?, + decode_open_col(decode_layout.funct7_bit[1])?, + decode_open_col(decode_layout.funct7_bit[2])?, + decode_open_col(decode_layout.funct7_bit[3])?, + decode_open_col(decode_layout.funct7_bit[4])?, + 
decode_open_col(decode_layout.funct7_bit[5])?, + decode_open_col(decode_layout.funct7_bit[6])?, + ]; + let rd_is_zero = decode_open_col(decode_layout.rd_is_zero)?; + let op_write_flags = [ + opcode_flags[0] * (K::ONE - rd_is_zero), + opcode_flags[1] * (K::ONE - rd_is_zero), + opcode_flags[2] * (K::ONE - rd_is_zero), + opcode_flags[3] * (K::ONE - rd_is_zero), + opcode_flags[7] * (K::ONE - rd_is_zero), + opcode_flags[8] * (K::ONE - rd_is_zero), + ]; + let alu_reg_table_delta = funct7_bits[5] * (funct3_is[0] + funct3_is[5]); + let alu_imm_table_delta = funct7_bits[5] * funct3_is[5]; + let rs2_decode = decode_open_col(decode_layout.rs2)?; + let imm_i = decode_open_col(decode_layout.imm_i)?; + let alu_imm_shift_rhs_delta = (funct3_is[1] + funct3_is[5]) * (rs2_decode - imm_i); + let shout_has_lookup = wp_open_col(trace.shout_has_lookup)?; + let rs1_val = wp_open_col(trace.rs1_val)?; + let shout_lhs = wp_open_col(trace.shout_lhs)?; + let shout_table_id = decode_open_col(decode_layout.shout_table_id)?; + + let selector_residuals = w2_decode_selector_residuals( + wp_open_col(trace.active)?, + decode_open_col(decode_layout.opcode)?, + opcode_flags, + funct3_is, + funct3_bits, + decode_open_col(decode_layout.op_amo)?, + ); + let bitness_residuals = w2_decode_bitness_residuals(opcode_flags, funct3_is); + let alu_branch_residuals = w2_alu_branch_lookup_residuals( + wp_open_col(trace.active)?, + wb_open_col(trace.halted)?, + shout_has_lookup, + shout_lhs, + wp_open_col(trace.shout_rhs)?, + shout_table_id, + rs1_val, + wp_open_col(trace.rs2_val)?, + decode_open_col(decode_layout.rd_has_write)?, + rd_is_zero, + wp_open_col(trace.rd_val)?, + decode_open_col(decode_layout.ram_has_read)?, + decode_open_col(decode_layout.ram_has_write)?, + wp_open_col(trace.ram_addr)?, + wp_open_col(trace.shout_val)?, + funct3_bits, + funct7_bits, + opcode_flags, + op_write_flags, + funct3_is, + alu_reg_table_delta, + alu_imm_table_delta, + alu_imm_shift_rhs_delta, + rs2_decode, + imm_i, + 
decode_open_col(decode_layout.imm_s)?, + ); + + let mut residuals = Vec::with_capacity(W2_FIELDS_RESIDUAL_COUNT); + residuals.extend_from_slice(&selector_residuals); + residuals.extend_from_slice(&bitness_residuals); + residuals.extend_from_slice(&alu_branch_residuals); + let mut weighted = K::ZERO; + let weights = w2_decode_pack_weight_vector(r_cycle, residuals.len()); + for (r, w) in residuals.iter().zip(weights.iter()) { + weighted += *w * *r; + } + let expected = eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "w2/decode_fields terminal value mismatch".into(), + )); + } + } + + if let Some(claim_idx) = claim_plan.decode_immediates { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "w2/decode_immediates claim index out of range".into(), + )); + } + let residuals = w2_decode_immediate_residuals( + decode_open_col(decode_layout.imm_i)?, + decode_open_col(decode_layout.imm_s)?, + decode_open_col(decode_layout.imm_b)?, + decode_open_col(decode_layout.imm_j)?, + [ + decode_open_col(decode_layout.rd_bit[0])?, + decode_open_col(decode_layout.rd_bit[1])?, + decode_open_col(decode_layout.rd_bit[2])?, + decode_open_col(decode_layout.rd_bit[3])?, + decode_open_col(decode_layout.rd_bit[4])?, + ], + [ + decode_open_col(decode_layout.funct3_bit[0])?, + decode_open_col(decode_layout.funct3_bit[1])?, + decode_open_col(decode_layout.funct3_bit[2])?, + ], + [ + decode_open_col(decode_layout.rs1_bit[0])?, + decode_open_col(decode_layout.rs1_bit[1])?, + decode_open_col(decode_layout.rs1_bit[2])?, + decode_open_col(decode_layout.rs1_bit[3])?, + decode_open_col(decode_layout.rs1_bit[4])?, + ], + [ + decode_open_col(decode_layout.rs2_bit[0])?, + decode_open_col(decode_layout.rs2_bit[1])?, + decode_open_col(decode_layout.rs2_bit[2])?, + decode_open_col(decode_layout.rs2_bit[3])?, + decode_open_col(decode_layout.rs2_bit[4])?, + ], + [ + 
decode_open_col(decode_layout.funct7_bit[0])?, + decode_open_col(decode_layout.funct7_bit[1])?, + decode_open_col(decode_layout.funct7_bit[2])?, + decode_open_col(decode_layout.funct7_bit[3])?, + decode_open_col(decode_layout.funct7_bit[4])?, + decode_open_col(decode_layout.funct7_bit[5])?, + decode_open_col(decode_layout.funct7_bit[6])?, + ], + ); + let mut weighted = K::ZERO; + let weights = w2_decode_imm_weight_vector(r_cycle, residuals.len()); + for (r, w) in residuals.iter().zip(weights.iter()) { + weighted += *w * *r; + } + let expected = eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "w2/decode_immediates terminal value mismatch".into(), + )); + } + } + + Ok(()) +} + +pub(crate) fn verify_route_a_width_terminals( + core_t: usize, + step: &StepInstanceBundle, + r_time: &[K], + r_cycle: &[K], + batched_final_values: &[K], + claim_plan: &RouteATimeClaimPlan, + mem_proof: &MemSidecarProof, +) -> Result<(), PiCcsError> { + let any_w3_claim = claim_plan.width_bitness.is_some() + || claim_plan.width_quiescence.is_some() + || claim_plan.width_selector_linkage.is_some() + || claim_plan.width_load_semantics.is_some() + || claim_plan.width_store_semantics.is_some(); + if !any_w3_claim { + return Ok(()); + } + + if mem_proof.wp_me_claims.len() != 1 { + return Err(PiCcsError::ProtocolError( + "W3 requires WP ME openings for shared main-trace terminals".into(), + )); + } + + let trace = Rv32TraceLayout::new(); + let width = Rv32WidthSidecarLayout::new(); + let decode = Rv32DecodeSidecarLayout::new(); + + let wp_me = &mem_proof.wp_me_claims[0]; + if wp_me.r.as_slice() != r_time { + return Err(PiCcsError::ProtocolError( + "W3 WP ME claim r mismatch (expected r_time)".into(), + )); + } + if wp_me.c != step.mcs_inst.c { + return Err(PiCcsError::ProtocolError("W3 WP ME claim commitment mismatch".into())); + } + if wp_me.m_in != step.mcs_inst.m_in { + return Err(PiCcsError::ProtocolError("W3 WP 
ME claim m_in mismatch".into())); + } + let wp_cols = rv32_trace_wp_opening_columns(&trace); + let need_wp = core_t + .checked_add(wp_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("W3 WP opening count overflow".into()))?; + if wp_me.y_scalars.len() < need_wp { + return Err(PiCcsError::ProtocolError(format!( + "W3 WP ME opening length mismatch (got {}, expected at least {need_wp})", + wp_me.y_scalars.len() + ))); + } + let wp_open = &wp_me.y_scalars[core_t..need_wp]; + let wp_open_col = |col_id: usize| -> Result { + let idx = wp_cols + .iter() + .position(|&c| c == col_id) + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing WP opening column {col_id}")))?; + Ok(wp_open[idx]) + }; + + let decode_open_cols = rv32_decode_lookup_backed_cols(&decode); + let control_extra_cols = if control_stage_required_for_step_instance(step) { + rv32_trace_control_extra_opening_columns(&trace) + } else { + Vec::new() + }; + let decode_open_start = core_t + .checked_add(wp_cols.len()) + .and_then(|v| v.checked_add(control_extra_cols.len())) + .ok_or_else(|| PiCcsError::InvalidInput("W3 decode opening start overflow".into()))?; + let decode_open_end = decode_open_start + .checked_add(decode_open_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("W3 decode opening end overflow".into()))?; + if wp_me.y_scalars.len() < decode_open_end { + return Err(PiCcsError::ProtocolError(format!( + "W3 decode openings missing on WP ME claim (got {}, need at least {decode_open_end})", + wp_me.y_scalars.len() + ))); + } + let decode_open = &wp_me.y_scalars[decode_open_start..decode_open_end]; + let decode_open_map: BTreeMap = decode_open_cols + .iter() + .copied() + .zip(decode_open.iter().copied()) + .collect(); + let decode_open_col = |col_id: usize| -> Result { + decode_open_map + .get(&col_id) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3(shared) missing decode opening col_id={col_id}"))) + }; + let width_open_cols = rv32_width_lookup_backed_cols(&width); + 
let width_open_start = decode_open_end; + let width_open_end = width_open_start + .checked_add(width_open_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("W3 width opening end overflow".into()))?; + if wp_me.y_scalars.len() < width_open_end { + return Err(PiCcsError::ProtocolError(format!( + "W3 width openings missing on WP ME claim (got {}, need at least {width_open_end})", + wp_me.y_scalars.len() + ))); + } + let width_open_map: BTreeMap = wp_me.y_scalars[width_open_start..width_open_end] + .iter() + .copied() + .zip(width_open_cols.iter().copied()) + .map(|(v, col_id)| (col_id, v)) + .collect(); + let width_open_col = |col_id: usize| -> Result { + width_open_map + .get(&col_id) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError(format!("W3 missing width opening col_id={col_id}"))) + }; + + let active = wp_open_col(trace.active)?; + let rd_has_write = decode_open_col(decode.rd_has_write)?; + let rd_val = wp_open_col(trace.rd_val)?; + let ram_has_read = decode_open_col(decode.ram_has_read)?; + let ram_has_write = decode_open_col(decode.ram_has_write)?; + let ram_rv = wp_open_col(trace.ram_rv)?; + let ram_wv = wp_open_col(trace.ram_wv)?; + let rs2_val = wp_open_col(trace.rs2_val)?; + + let mut ram_rv_low_bits = [K::ZERO; 16]; + let mut rs2_low_bits = [K::ZERO; 16]; + for k in 0..16 { + ram_rv_low_bits[k] = width_open_col(width.ram_rv_low_bit[k])?; + rs2_low_bits[k] = width_open_col(width.rs2_low_bit[k])?; + } + let ram_rv_q16 = width_open_col(width.ram_rv_q16)?; + let rs2_q16 = width_open_col(width.rs2_q16)?; + let funct3_is = [ + decode_open_col(decode.funct3_is[0])?, + decode_open_col(decode.funct3_is[1])?, + decode_open_col(decode.funct3_is[2])?, + decode_open_col(decode.funct3_is[3])?, + decode_open_col(decode.funct3_is[4])?, + decode_open_col(decode.funct3_is[5])?, + decode_open_col(decode.funct3_is[6])?, + decode_open_col(decode.funct3_is[7])?, + ]; + let op_load = decode_open_col(decode.op_load)?; + let op_store = decode_open_col(decode.op_store)?; 
+ let load_flags = [ + op_load * funct3_is[0], + op_load * funct3_is[4], + op_load * funct3_is[1], + op_load * funct3_is[5], + op_load * funct3_is[2], + ]; + let store_flags = [ + op_store * funct3_is[0], + op_store * funct3_is[1], + op_store * funct3_is[2], + ]; + + if let Some(claim_idx) = claim_plan.width_bitness { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError("w3/bitness claim index out of range".into())); + } + let mut bitness_open = Vec::with_capacity(32); + bitness_open.extend_from_slice(&ram_rv_low_bits); + bitness_open.extend_from_slice(&rs2_low_bits); + let weights = w3_bitness_weight_vector(r_cycle, bitness_open.len()); + let mut weighted = K::ZERO; + for (b, w) in bitness_open.iter().zip(weights.iter()) { + weighted += *w * *b * (*b - K::ONE); + } + let expected = eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError("w3/bitness terminal value mismatch".into())); + } + } + + if let Some(claim_idx) = claim_plan.width_quiescence { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "w3/quiescence claim index out of range".into(), + )); + } + let mut quiescence_open = vec![ram_rv_q16, rs2_q16]; + quiescence_open.extend_from_slice(&ram_rv_low_bits); + quiescence_open.extend_from_slice(&rs2_low_bits); + let weights = w3_quiescence_weight_vector(r_cycle, quiescence_open.len()); + let mut weighted = K::ZERO; + for (v, w) in quiescence_open.iter().zip(weights.iter()) { + weighted += *w * *v; + } + let expected = eq_points(r_time, r_cycle) * (K::ONE - active) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "w3/quiescence terminal value mismatch".into(), + )); + } + } + + if claim_plan.width_selector_linkage.is_some() { + return Err(PiCcsError::ProtocolError( + "w3/selector_linkage must be disabled in reduced width-sidecar mode".into(), + )); + } + + if let 
Some(claim_idx) = claim_plan.width_load_semantics { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "w3/load_semantics claim index out of range".into(), + )); + } + let residuals = w3_load_semantics_residuals( + rd_val, + ram_rv, + rd_has_write, + ram_has_read, + load_flags, + ram_rv_q16, + ram_rv_low_bits, + ); + let weights = w3_load_weight_vector(r_cycle, residuals.len()); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(weights.iter()) { + weighted += *w * *r; + } + let expected = eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "w3/load_semantics terminal value mismatch".into(), + )); + } + } + + if let Some(claim_idx) = claim_plan.width_store_semantics { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "w3/store_semantics claim index out of range".into(), + )); + } + let residuals = w3_store_semantics_residuals( + ram_wv, + ram_rv, + rs2_val, + rd_has_write, + ram_has_read, + ram_has_write, + store_flags, + rs2_q16, + ram_rv_low_bits, + rs2_low_bits, + ); + let weights = w3_store_weight_vector(r_cycle, residuals.len()); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(weights.iter()) { + weighted += *w * *r; + } + let expected = eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "w3/store_semantics terminal value mismatch".into(), + )); + } + } + + Ok(()) +} + +pub(crate) fn verify_route_a_control_terminals( + core_t: usize, + step: &StepInstanceBundle, + r_time: &[K], + r_cycle: &[K], + batched_final_values: &[K], + claim_plan: &RouteATimeClaimPlan, + mem_proof: &MemSidecarProof, +) -> Result<(), PiCcsError> { + let any_control_claim = claim_plan.control_next_pc_linear.is_some() + || claim_plan.control_next_pc_control.is_some() + || claim_plan.control_branch_semantics.is_some() + || 
claim_plan.control_writeback.is_some(); + if !any_control_claim { + return Ok(()); + } + + if mem_proof.wp_me_claims.len() != 1 { + return Err(PiCcsError::ProtocolError( + "control stage requires WP ME openings for main-trace terminals".into(), + )); + } + let trace = Rv32TraceLayout::new(); + let decode = Rv32DecodeSidecarLayout::new(); + + let wp_me = &mem_proof.wp_me_claims[0]; + if wp_me.r.as_slice() != r_time { + return Err(PiCcsError::ProtocolError( + "control stage WP ME claim r mismatch (expected r_time)".into(), + )); + } + if wp_me.c != step.mcs_inst.c { + return Err(PiCcsError::ProtocolError( + "control stage WP ME claim commitment mismatch".into(), + )); + } + if wp_me.m_in != step.mcs_inst.m_in { + return Err(PiCcsError::ProtocolError( + "control stage WP ME claim m_in mismatch".into(), + )); + } + let wp_base_cols = rv32_trace_wp_opening_columns(&trace); + let control_extra_cols = rv32_trace_control_extra_opening_columns(&trace); + let need_wp_min = core_t + .checked_add(wp_base_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("control stage WP opening count overflow".into()))?; + if wp_me.y_scalars.len() < need_wp_min { + return Err(PiCcsError::ProtocolError(format!( + "control stage WP ME opening length mismatch (got {}, expected at least {need_wp_min})", + wp_me.y_scalars.len() + ))); + } + let need_control_min = need_wp_min + .checked_add(control_extra_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("control stage WP+extra opening count overflow".into()))?; + if wp_me.y_scalars.len() < need_control_min { + return Err(PiCcsError::ProtocolError(format!( + "control stage requires control extra WP openings (got {}, expected at least {need_control_min})", + wp_me.y_scalars.len() + ))); + } + let wp_open = &wp_me.y_scalars[core_t..]; + let wp_open_col = |col_id: usize| -> Result { + if let Some(idx) = wp_base_cols.iter().position(|&c| c == col_id) { + return Ok(wp_open[idx]); + } + if let Some(extra_idx) = 
control_extra_cols.iter().position(|&c| c == col_id) { + let idx = wp_base_cols + .len() + .checked_add(extra_idx) + .ok_or_else(|| PiCcsError::InvalidInput("control stage WP extra index overflow".into()))?; + return wp_open.get(idx).copied().ok_or_else(|| { + PiCcsError::ProtocolError(format!("control stage missing WP extra opening column {col_id}")) + }); + } + Err(PiCcsError::ProtocolError(format!( + "control stage missing WP opening column {col_id}" + ))) + }; + let decode_open_cols = rv32_decode_lookup_backed_cols(&decode); + let decode_open_start = need_control_min; + let decode_open_end = decode_open_start + .checked_add(decode_open_cols.len()) + .ok_or_else(|| PiCcsError::InvalidInput("control stage decode opening end overflow".into()))?; + if wp_me.y_scalars.len() < decode_open_end { + return Err(PiCcsError::ProtocolError(format!( + "control stage decode openings missing on WP ME claim (got {}, need at least {decode_open_end})", + wp_me.y_scalars.len() + ))); + } + let decode_open = &wp_me.y_scalars[decode_open_start..decode_open_end]; + let decode_open_map: BTreeMap = decode_open_cols + .iter() + .copied() + .zip(decode_open.iter().copied()) + .collect(); + let decode_open_col = |col_id: usize| -> Result { + decode_open_map + .get(&col_id) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError(format!("control(shared) missing decode opening col_id={col_id}"))) + }; + + let active = wp_open_col(trace.active)?; + let pc_before = wp_open_col(trace.pc_before)?; + let pc_after = wp_open_col(trace.pc_after)?; + let rs1_val = wp_open_col(trace.rs1_val)?; + let rd_val = wp_open_col(trace.rd_val)?; + let jalr_drop_bit = wp_open_col(trace.jalr_drop_bit)?; + let shout_val = wp_open_col(trace.shout_val)?; + let funct3_bits = [ + decode_open_col(decode.funct3_bit[0])?, + decode_open_col(decode.funct3_bit[1])?, + decode_open_col(decode.funct3_bit[2])?, + ]; + let rs1_bits = [ + decode_open_col(decode.rs1_bit[0])?, + decode_open_col(decode.rs1_bit[1])?, + 
decode_open_col(decode.rs1_bit[2])?, + decode_open_col(decode.rs1_bit[3])?, + decode_open_col(decode.rs1_bit[4])?, + ]; + let rs2_bits = [ + decode_open_col(decode.rs2_bit[0])?, + decode_open_col(decode.rs2_bit[1])?, + decode_open_col(decode.rs2_bit[2])?, + decode_open_col(decode.rs2_bit[3])?, + decode_open_col(decode.rs2_bit[4])?, + ]; + let funct7_bits = [ + decode_open_col(decode.funct7_bit[0])?, + decode_open_col(decode.funct7_bit[1])?, + decode_open_col(decode.funct7_bit[2])?, + decode_open_col(decode.funct7_bit[3])?, + decode_open_col(decode.funct7_bit[4])?, + decode_open_col(decode.funct7_bit[5])?, + decode_open_col(decode.funct7_bit[6])?, + ]; + + let op_lui = decode_open_col(decode.op_lui)?; + let op_auipc = decode_open_col(decode.op_auipc)?; + let op_jal = decode_open_col(decode.op_jal)?; + let op_jalr = decode_open_col(decode.op_jalr)?; + let op_branch = decode_open_col(decode.op_branch)?; + let op_load = decode_open_col(decode.op_load)?; + let op_store = decode_open_col(decode.op_store)?; + let op_alu_imm = decode_open_col(decode.op_alu_imm)?; + let op_alu_reg = decode_open_col(decode.op_alu_reg)?; + let op_misc_mem = decode_open_col(decode.op_misc_mem)?; + let op_system = decode_open_col(decode.op_system)?; + let op_amo = decode_open_col(decode.op_amo)?; + let rd_is_zero = decode_open_col(decode.rd_is_zero)?; + let op_lui_write = op_lui * (K::ONE - rd_is_zero); + let op_auipc_write = op_auipc * (K::ONE - rd_is_zero); + let op_jal_write = op_jal * (K::ONE - rd_is_zero); + let op_jalr_write = op_jalr * (K::ONE - rd_is_zero); + let imm_i = decode_open_col(decode.imm_i)?; + let imm_b = decode_open_col(decode.imm_b)?; + let imm_j = decode_open_col(decode.imm_j)?; + let funct3_is6 = decode_open_col(decode.funct3_is[6])?; + let funct3_is7 = decode_open_col(decode.funct3_is[7])?; + + if let Some(claim_idx) = claim_plan.control_next_pc_linear { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "control/next_pc_linear claim 
index out of range".into(), + )); + } + let residual = control_next_pc_linear_residual( + pc_before, + pc_after, + op_lui, + op_auipc, + op_load, + op_store, + op_alu_imm, + op_alu_reg, + op_misc_mem, + op_system, + op_amo, + ); + let weights = control_next_pc_linear_weight_vector(r_cycle, 1); + let expected = eq_points(r_time, r_cycle) * weights[0] * residual; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "control/next_pc_linear terminal value mismatch".into(), + )); + } + } + + if let Some(claim_idx) = claim_plan.control_next_pc_control { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "control/next_pc_control claim index out of range".into(), + )); + } + let residuals = control_next_pc_control_residuals( + active, + pc_before, + pc_after, + rs1_val, + jalr_drop_bit, + imm_i, + imm_b, + imm_j, + op_jal, + op_jalr, + op_branch, + shout_val, + funct3_bits[0], + ); + let weights = control_next_pc_control_weight_vector(r_cycle, residuals.len()); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(weights.iter()) { + weighted += *w * *r; + } + let expected = eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "control/next_pc_control terminal value mismatch".into(), + )); + } + } + + if let Some(claim_idx) = claim_plan.control_branch_semantics { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "control/branch_semantics claim index out of range".into(), + )); + } + let residuals = control_branch_semantics_residuals( + op_branch, + shout_val, + funct3_bits[0], + funct3_bits[1], + funct3_bits[2], + funct3_is6, + funct3_is7, + ); + let weights = control_branch_semantics_weight_vector(r_cycle, residuals.len()); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(weights.iter()) { + weighted += *w * *r; + } + let expected = 
eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "control/branch_semantics terminal value mismatch".into(), + )); + } + } + + if let Some(claim_idx) = claim_plan.control_writeback { + if claim_idx >= batched_final_values.len() { + return Err(PiCcsError::ProtocolError( + "control/writeback claim index out of range".into(), + )); + } + let imm_u = control_imm_u_from_bits(funct3_bits, rs1_bits, rs2_bits, funct7_bits); + let residuals = control_writeback_residuals( + rd_val, + pc_before, + imm_u, + op_lui_write, + op_auipc_write, + op_jal_write, + op_jalr_write, + ); + let weights = control_writeback_weight_vector(r_cycle, residuals.len()); + let mut weighted = K::ZERO; + for (r, w) in residuals.iter().zip(weights.iter()) { + weighted += *w * *r; + } + let expected = eq_points(r_time, r_cycle) * weighted; + if batched_final_values[claim_idx] != expected { + return Err(PiCcsError::ProtocolError( + "control/writeback terminal value mismatch".into(), + )); + } + } + + Ok(()) +} diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_verify.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_verify.rs new file mode 100644 index 00000000..4054c2b3 --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_verify.rs @@ -0,0 +1,1062 @@ +use super::*; + +pub fn verify_route_a_memory_step( + tr: &mut Poseidon2Transcript, + cpu_bus: &BusLayout, + m: usize, + core_t: usize, + step: &StepInstanceBundle, + prev_step: Option<&StepInstanceBundle>, + ccs_out0: &MeInstance, + r_time: &[K], + r_cycle: &[K], + batched_final_values: &[K], + batched_claimed_sums: &[K], + claim_idx_start: usize, + mem_proof: &MemSidecarProof, + shout_pre: &[ShoutAddrPreVerifyData], + twist_pre: &[TwistAddrPreVerifyData], + step_idx: usize, +) -> Result { + let chi_cycle_at_r_time = eq_points(r_time, r_cycle); + if ccs_out0.r.as_slice() != r_time { + return Err(PiCcsError::ProtocolError( + "CPU ME output r 
mismatch (expected shared r_time)".into(), + )); + } + let trace_mode = wb_wp_required_for_step_instance(step); + let cpu_link = if trace_mode { + extract_trace_cpu_link_openings(m, core_t, cpu_bus.bus_cols, step, ccs_out0)? + } else { + None + }; + let enforce_trace_shout_linkage = trace_mode && !step.lut_insts.is_empty(); + if enforce_trace_shout_linkage && cpu_link.is_none() { + return Err(PiCcsError::ProtocolError( + "missing CPU trace linkage openings in shared-bus mode".into(), + )); + } + let has_prev = prev_step.is_some(); + if let Some(prev) = prev_step { + if prev.mem_insts.len() != step.mem_insts.len() { + return Err(PiCcsError::InvalidInput(format!( + "Twist rollover requires stable mem instance count: prev has {}, current has {}", + prev.mem_insts.len(), + step.mem_insts.len() + ))); + } + for (idx, (prev_inst, inst)) in prev.mem_insts.iter().zip(step.mem_insts.iter()).enumerate() { + if prev_inst.d != inst.d + || prev_inst.ell != inst.ell + || prev_inst.k != inst.k + || prev_inst.lanes != inst.lanes + { + return Err(PiCcsError::InvalidInput(format!( + "Twist rollover requires stable geometry at mem_idx={}: prev (k={}, d={}, ell={}, lanes={}) vs cur (k={}, d={}, ell={}, lanes={})", + idx, + prev_inst.k, + prev_inst.d, + prev_inst.ell, + prev_inst.lanes, + inst.k, + inst.d, + inst.ell, + inst.lanes + ))); + } + } + } + + for (idx, inst) in step.lut_insts.iter().enumerate() { + if !inst.comms.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "shared CPU bus requires metadata-only Shout instances (comms must be empty, lut_idx={idx})" + ))); + } + } + for (idx, inst) in step.mem_insts.iter().enumerate() { + if !inst.comms.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "shared CPU bus requires metadata-only Twist instances (comms must be empty, mem_idx={idx})" + ))); + } + } + if let Some(prev) = prev_step { + for (idx, inst) in prev.lut_insts.iter().enumerate() { + if !inst.comms.is_empty() { + return 
Err(PiCcsError::InvalidInput(format!( + "shared CPU bus requires metadata-only Shout instances (comms must be empty, prev lut_idx={idx})" + ))); + } + } + for (idx, inst) in prev.mem_insts.iter().enumerate() { + if !inst.comms.is_empty() { + return Err(PiCcsError::InvalidInput(format!( + "shared CPU bus requires metadata-only Twist instances (comms must be empty, prev mem_idx={idx})" + ))); + } + } + } + + let proofs_mem = &mem_proof.proofs; + + if cpu_bus.shout_cols.len() != step.lut_insts.len() || cpu_bus.twist_cols.len() != step.mem_insts.len() { + return Err(PiCcsError::InvalidInput( + "shared_cpu_bus layout mismatch for step (instance counts)".into(), + )); + } + + let bus_y_base_time = if cpu_bus.bus_cols > 0 { + let min_len = core_t + .checked_add(cpu_bus.bus_cols) + .ok_or_else(|| PiCcsError::InvalidInput("core_t + bus_cols overflow".into()))?; + if ccs_out0.y_scalars.len() < min_len { + return Err(PiCcsError::InvalidInput( + "CPU y_scalars too short for shared-bus openings".into(), + )); + } + core_t + } else { + 0usize + }; + let wb_enabled = wb_wp_required_for_step_instance(step); + let wp_enabled = wb_wp_required_for_step_instance(step); + let w2_enabled = decode_stage_required_for_step_instance(step); + let w3_enabled = width_stage_required_for_step_instance(step); + let control_enabled = control_stage_required_for_step_instance(step); + let claim_plan = RouteATimeClaimPlan::build( + step, + claim_idx_start, + wb_enabled, + wp_enabled, + w2_enabled, + w3_enabled, + control_enabled, + )?; + if claim_plan.claim_idx_end > batched_final_values.len() { + return Err(PiCcsError::InvalidInput(format!( + "batched_final_values too short (need at least {}, have {})", + claim_plan.claim_idx_end, + batched_final_values.len() + ))); + } + if claim_plan.claim_idx_end > batched_claimed_sums.len() { + return Err(PiCcsError::InvalidInput(format!( + "batched_claimed_sums too short (need at least {}, have {})", + claim_plan.claim_idx_end, + batched_claimed_sums.len() + 
))); + } + + let expected_proofs = step.lut_insts.len() + step.mem_insts.len(); + if proofs_mem.len() != expected_proofs { + return Err(PiCcsError::InvalidInput(format!( + "mem proof count mismatch (expected {}, got {})", + expected_proofs, + proofs_mem.len() + ))); + } + let total_shout_lanes: usize = step.lut_insts.iter().map(|inst| inst.lanes.max(1)).sum(); + if shout_pre.len() != total_shout_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shout pre-time count mismatch (expected total_lanes={}, got {})", + total_shout_lanes, + shout_pre.len() + ))); + } + if twist_pre.len() != step.mem_insts.len() { + return Err(PiCcsError::InvalidInput(format!( + "twist pre-time count mismatch (expected {}, got {})", + step.mem_insts.len(), + twist_pre.len() + ))); + } + + let mut twist_time_openings: Vec = Vec::with_capacity(step.mem_insts.len()); + + // Shout instances first. + let mut shout_lane_base: usize = 0; + let mut shout_trace_sums = ShoutTraceLinkSums::default(); + #[derive(Clone)] + struct ShoutGammaLaneVerifyData { + has_lookup: K, + val: K, + addr_bits: Vec, + pre: ShoutAddrPreVerifyData, + } + let mut shout_addr_range_counts = std::collections::HashMap::<(usize, usize), usize>::new(); + for inst_cols in cpu_bus.shout_cols.iter() { + for lane_cols in inst_cols.lanes.iter() { + let key = (lane_cols.addr_bits.start, lane_cols.addr_bits.end); + *shout_addr_range_counts.entry(key).or_insert(0) += 1; + } + } + let mut shout_gamma_lane_data: Vec> = vec![None; total_shout_lanes]; + for (proof_idx, inst) in step.lut_insts.iter().enumerate() { + match &proofs_mem[proof_idx] { + MemOrLutProof::Shout(_proof) => {} + _ => return Err(PiCcsError::InvalidInput("expected Shout proof".into())), + } + if matches!( + inst.table_spec, + Some(LutTableSpec::RiscvOpcodePacked { .. } | LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}) + ) { + return Err(PiCcsError::InvalidInput( + "packed RISC-V Shout table specs are not supported on the shared CPU bus".into(), + )); + } + + let ell_addr = inst.d * inst.ell; + let expected_lanes = inst.lanes.max(1); + let lane_table_id = if enforce_trace_shout_linkage { + rv32_trace_link_table_id_from_spec(&inst.table_spec)?.map(|table_id| K::from(F::from_u64(table_id as u64))) + } else { + None + }; + + let inst_cols = cpu_bus + .shout_cols + .get(proof_idx) + .ok_or_else(|| PiCcsError::InvalidInput("shared_cpu_bus layout mismatch (shout)".into()))?; + if inst_cols.lanes.len() != expected_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at lut_idx={proof_idx}: bus shout lanes={} but instance expects {expected_lanes}", + inst_cols.lanes.len() + ))); + } + + struct ShoutLaneOpen { + addr_bits: Vec, + has_lookup: K, + val: K, + shared_addr_group: bool, + shared_addr_group_size: usize, + } + let mut lane_opens: Vec = Vec::with_capacity(expected_lanes); + for (lane_idx, shout_cols) in inst_cols.lanes.iter().enumerate() { + if shout_cols.addr_bits.end - shout_cols.addr_bits.start != ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at lut_idx={proof_idx}, lane_idx={lane_idx}: expected ell_addr={ell_addr}" + ))); + } + + let mut addr_bits_open = Vec::with_capacity(ell_addr); + for (_j, col_id) in shout_cols.addr_bits.clone().enumerate() { + addr_bits_open.push( + ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, col_id)) + .copied() + .ok_or_else(|| { + PiCcsError::ProtocolError("CPU y_scalars missing Shout addr_bits opening".into()) + })?, + ); + } + let has_lookup_open = ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, shout_cols.has_lookup)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Shout has_lookup opening".into()))?; + let val_open = ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, 
shout_cols.primary_val())) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Shout val opening".into()))?; + let key = (shout_cols.addr_bits.start, shout_cols.addr_bits.end); + let shared_addr_group_size = shout_addr_range_counts.get(&key).copied().unwrap_or(0); + let shared_addr_group = shared_addr_group_size > 1; + + lane_opens.push(ShoutLaneOpen { + addr_bits: addr_bits_open, + has_lookup: has_lookup_open, + val: val_open, + shared_addr_group, + shared_addr_group_size, + }); + } + + let shout_claims = claim_plan + .shout + .get(proof_idx) + .ok_or_else(|| PiCcsError::ProtocolError(format!("missing Shout claim schedule at index {}", proof_idx)))?; + if shout_claims.lanes.len() != expected_lanes { + return Err(PiCcsError::ProtocolError(format!( + "Shout claim schedule lane count mismatch at lut_idx={proof_idx}: expected {expected_lanes}, got {}", + shout_claims.lanes.len() + ))); + } + if shout_lane_base + .checked_add(expected_lanes) + .ok_or_else(|| PiCcsError::ProtocolError("shout lane index overflow".into()))? 
+ > shout_pre.len() + { + return Err(PiCcsError::ProtocolError("Shout pre-time lane indexing drift".into())); + } + + // Route A Shout ordering in batched_time: + // - value (time rounds only) per lane + // - adapter (time rounds only) per lane + // - aggregated bitness for (addr_bits, has_lookup) + { + let mut opens: Vec = Vec::with_capacity(expected_lanes * (ell_addr + 1)); + for lane in lane_opens.iter() { + opens.extend_from_slice(&lane.addr_bits); + opens.push(lane.has_lookup); + } + let weights = bitness_weights(r_cycle, opens.len(), 0x5348_4F55_54u64 + proof_idx as u64); + let mut acc = K::ZERO; + for (w, b) in weights.iter().zip(opens.iter()) { + acc += *w * *b * (*b - K::ONE); + } + let expected = chi_cycle_at_r_time * acc; + if expected != batched_final_values[shout_claims.bitness] { + return Err(PiCcsError::ProtocolError( + "shout/bitness terminal value mismatch".into(), + )); + } + } + + for (lane_idx, lane) in lane_opens.iter().enumerate() { + if let Some(lane_table_id) = lane_table_id { + shout_trace_sums.has_lookup += lane.has_lookup; + shout_trace_sums.val += lane.val; + shout_trace_sums.table_id += lane.has_lookup * lane_table_id; + let (lhs, rhs) = unpack_interleaved_halves_lsb(&lane.addr_bits)?; + if lane.shared_addr_group { + let inv_count = K::from_u64(lane.shared_addr_group_size as u64).inverse(); + shout_trace_sums.lhs += lhs * inv_count; + shout_trace_sums.rhs += rhs * inv_count; + } else { + shout_trace_sums.lhs += lhs; + shout_trace_sums.rhs += rhs; + } + } + + let pre = shout_pre.get(shout_lane_base + lane_idx).ok_or_else(|| { + PiCcsError::InvalidInput(format!( + "missing pre-time Shout lane data at index {}", + shout_lane_base + lane_idx + )) + })?; + let lane_claims = shout_claims + .lanes + .get(lane_idx) + .ok_or_else(|| PiCcsError::ProtocolError("shout claim schedule lane idx drift".into()))?; + + if lane_claims.gamma_group.is_some() { + if !pre.is_active { + if pre.addr_claim_sum != K::ZERO || pre.addr_final != K::ZERO || 
lane.has_lookup != K::ZERO { + return Err(PiCcsError::ProtocolError( + "shout gamma lane inactive-row invariants violated".into(), + )); + } + } + shout_gamma_lane_data[shout_lane_base + lane_idx] = Some(ShoutGammaLaneVerifyData { + has_lookup: lane.has_lookup, + val: lane.val, + addr_bits: lane.addr_bits.clone(), + pre: pre.clone(), + }); + } else { + let value_idx = lane_claims + .value + .ok_or_else(|| PiCcsError::ProtocolError("missing shout value claim idx".into()))?; + let adapter_idx = lane_claims + .adapter + .ok_or_else(|| PiCcsError::ProtocolError("missing shout adapter claim idx".into()))?; + let value_claim = batched_claimed_sums[value_idx]; + let value_final = batched_final_values[value_idx]; + let adapter_claim = batched_claimed_sums[adapter_idx]; + let adapter_final = batched_final_values[adapter_idx]; + + let expected_value_final = chi_cycle_at_r_time * lane.has_lookup * lane.val; + if expected_value_final != value_final { + return Err(PiCcsError::ProtocolError("shout value terminal value mismatch".into())); + } + + let eq_addr = eq_bits_prod(&lane.addr_bits, &pre.r_addr)?; + let expected_adapter_final = chi_cycle_at_r_time * lane.has_lookup * eq_addr; + if expected_adapter_final != adapter_final { + return Err(PiCcsError::ProtocolError( + "shout adapter terminal value mismatch".into(), + )); + } + + if value_claim != pre.addr_claim_sum { + return Err(PiCcsError::ProtocolError( + "shout value claimed sum != addr claimed sum".into(), + )); + } + + if pre.is_active { + let expected_addr_final = pre.table_eval_at_r_addr * adapter_claim; + if expected_addr_final != pre.addr_final { + return Err(PiCcsError::ProtocolError("shout addr terminal value mismatch".into())); + } + } else { + // If we skipped the addr-pre sumcheck, the only sound case is "no lookups". + // Enforce this by requiring the addr claim + adapter claim to be zero. 
+ if pre.addr_claim_sum != K::ZERO { + return Err(PiCcsError::ProtocolError( + "shout addr-pre skipped but addr claim is nonzero".into(), + )); + } + if adapter_claim != K::ZERO { + return Err(PiCcsError::ProtocolError( + "shout addr-pre skipped but adapter claim is nonzero".into(), + )); + } + if pre.addr_final != K::ZERO { + return Err(PiCcsError::ProtocolError( + "shout addr-pre skipped but addr_final is nonzero".into(), + )); + } + } + } + } + + shout_lane_base += expected_lanes; + } + if shout_lane_base != shout_pre.len() { + return Err(PiCcsError::ProtocolError( + "shout pre-time lanes not fully consumed".into(), + )); + } + if !step.lut_insts.is_empty() && enforce_trace_shout_linkage { + let cpu = cpu_link + .ok_or_else(|| PiCcsError::ProtocolError("missing CPU trace linkage openings in shared-bus mode".into()))?; + let expected_table_id = if decode_stage_required_for_step_instance(step) { + Some(expected_trace_shout_table_id_from_openings( + core_t, step, mem_proof, r_time, + )?) 
+ } else { + None + }; + verify_non_event_trace_shout_linkage(cpu, shout_trace_sums, expected_table_id)?; + } + + for group in claim_plan.shout_gamma_groups.iter() { + let weights = bitness_weights(r_cycle, group.lanes.len(), 0x5348_5F47_414D_4Du64 ^ group.key); + let value_claim = batched_claimed_sums[group.value]; + let value_final = batched_final_values[group.value]; + let adapter_claim = batched_claimed_sums[group.adapter]; + let adapter_final = batched_final_values[group.adapter]; + + let mut expected_value_claim = K::ZERO; + let mut expected_value_final = K::ZERO; + let mut expected_adapter_claim = K::ZERO; + let mut expected_adapter_final = K::ZERO; + for (slot, lane_ref) in group.lanes.iter().enumerate() { + let lane = shout_gamma_lane_data + .get(lane_ref.flat_lane_idx) + .and_then(|x| x.as_ref()) + .ok_or_else(|| PiCcsError::ProtocolError("missing shout gamma lane verify data".into()))?; + let w = weights[slot]; + let eq_addr = eq_bits_prod(&lane.addr_bits, &lane.pre.r_addr)?; + expected_value_claim += w * lane.pre.addr_claim_sum; + expected_value_final += w * lane.has_lookup * lane.val; + expected_adapter_claim += w * lane.pre.addr_final; + expected_adapter_final += w * lane.pre.table_eval_at_r_addr * lane.has_lookup * eq_addr; + } + expected_value_final *= chi_cycle_at_r_time; + expected_adapter_final *= chi_cycle_at_r_time; + + if value_claim != expected_value_claim { + return Err(PiCcsError::ProtocolError( + "shout gamma value claimed sum mismatch".into(), + )); + } + if value_final != expected_value_final { + return Err(PiCcsError::ProtocolError("shout gamma value terminal mismatch".into())); + } + if adapter_claim != expected_adapter_claim { + return Err(PiCcsError::ProtocolError( + "shout gamma adapter claimed sum mismatch".into(), + )); + } + if adapter_final != expected_adapter_final { + return Err(PiCcsError::ProtocolError( + "shout gamma adapter terminal mismatch".into(), + )); + } + } + + // Twist instances next. 
+ let proof_mem_offset = step.lut_insts.len(); + + // -------------------------------------------------------------------- + // Twist time checks at addr-pre `r_addr`. + // -------------------------------------------------------------------- + for (i_mem, inst) in step.mem_insts.iter().enumerate() { + let twist_proof = match &proofs_mem[proof_mem_offset + i_mem] { + MemOrLutProof::Twist(proof) => proof, + _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), + }; + let layout = inst.twist_layout(); + let ell_addr = layout + .lanes + .get(0) + .ok_or_else(|| PiCcsError::InvalidInput("TwistWitnessLayout has no lanes".into()))? + .ell_addr; + + let twist_inst_cols = cpu_bus + .twist_cols + .get(i_mem) + .ok_or_else(|| PiCcsError::InvalidInput("shared_cpu_bus layout mismatch (twist)".into()))?; + let expected_lanes = inst.lanes.max(1); + if twist_inst_cols.lanes.len() != expected_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at mem_idx={i_mem}: expected lanes={expected_lanes}, got {}", + twist_inst_cols.lanes.len() + ))); + } + + struct TwistLaneTimeOpen { + ra_bits: Vec, + wa_bits: Vec, + has_read: K, + has_write: K, + wv: K, + rv: K, + inc: K, + } + + let mut lane_opens: Vec = Vec::with_capacity(twist_inst_cols.lanes.len()); + for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { + if twist_cols.ra_bits.end - twist_cols.ra_bits.start != ell_addr + || twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr + { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at mem_idx={i_mem}, lane={lane_idx}: expected ell_addr={ell_addr}" + ))); + } + + let mut ra_bits_open = Vec::with_capacity(ell_addr); + for col_id in twist_cols.ra_bits.clone() { + ra_bits_open.push( + ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, col_id)) + .copied() + .ok_or_else(|| { + PiCcsError::ProtocolError("CPU y_scalars missing Twist ra_bits opening".into()) + 
})?, + ); + } + let mut wa_bits_open = Vec::with_capacity(ell_addr); + for col_id in twist_cols.wa_bits.clone() { + wa_bits_open.push( + ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, col_id)) + .copied() + .ok_or_else(|| { + PiCcsError::ProtocolError("CPU y_scalars missing Twist wa_bits opening".into()) + })?, + ); + } + + let has_read_open = ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.has_read)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist has_read opening".into()))?; + let has_write_open = ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.has_write)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist has_write opening".into()))?; + let wv_open = ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.wv)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist wv opening".into()))?; + let rv_open = ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.rv)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist rv opening".into()))?; + let inc_write_open = ccs_out0 + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_time, twist_cols.inc)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing Twist inc opening".into()))?; + + lane_opens.push(TwistLaneTimeOpen { + ra_bits: ra_bits_open, + wa_bits: wa_bits_open, + has_read: has_read_open, + has_write: has_write_open, + wv: wv_open, + rv: rv_open, + inc: inc_write_open, + }); + } + + let pre = twist_pre + .get(i_mem) + .ok_or_else(|| PiCcsError::InvalidInput(format!("missing Twist pre-time data at index {}", i_mem)))?; + let r_addr = &pre.r_addr; + if r_addr.len() != ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "Twist r_addr.len()={}, expected ell_addr={}", + r_addr.len(), + ell_addr + ))); + } + + let twist_claims 
= claim_plan + .twist + .get(i_mem) + .ok_or_else(|| PiCcsError::ProtocolError(format!("missing Twist claim schedule at index {}", i_mem)))?; + + // Route A Twist ordering in batched_time: + // - read_check (time rounds only) + // - write_check (time rounds only) + // - bitness for ra_bits then wa_bits then has_read then has_write (time-only) + let read_check_claim = batched_claimed_sums[twist_claims.read_check]; + let read_check_final = batched_final_values[twist_claims.read_check]; + let write_check_claim = batched_claimed_sums[twist_claims.write_check]; + let write_check_final = batched_final_values[twist_claims.write_check]; + + if read_check_claim != pre.read_check_claim_sum { + return Err(PiCcsError::ProtocolError( + "twist read_check claimed sum != addr-pre final".into(), + )); + } + if write_check_claim != pre.write_check_claim_sum { + return Err(PiCcsError::ProtocolError( + "twist write_check claimed sum != addr-pre final".into(), + )); + } + + // Aggregated bitness terminal check (ra_bits, wa_bits, has_read, has_write). 
+ { + let mut opens: Vec = Vec::with_capacity(expected_lanes * (2 * ell_addr + 2)); + for lane in lane_opens.iter() { + opens.extend_from_slice(&lane.ra_bits); + opens.extend_from_slice(&lane.wa_bits); + opens.push(lane.has_read); + opens.push(lane.has_write); + } + let weights = bitness_weights(r_cycle, opens.len(), 0x5457_4953_54u64 + i_mem as u64); + let mut acc = K::ZERO; + for (w, b) in weights.iter().zip(opens.iter()) { + acc += *w * *b * (*b - K::ONE); + } + let expected = chi_cycle_at_r_time * acc; + if expected != batched_final_values[twist_claims.bitness] { + return Err(PiCcsError::ProtocolError( + "twist/bitness terminal value mismatch".into(), + )); + } + } + + let val_eval = twist_proof + .val_eval + .as_ref() + .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing val_eval proof".into()))?; + + let init_at_r_addr = eval_init_at_r_addr(&inst.init, inst.k, r_addr)?; + let claimed_val = init_at_r_addr + val_eval.claimed_inc_sum_lt; + + // Terminal checks for read_check / write_check at (r_time, r_addr). 
+ let mut expected_read_check_final = K::ZERO; + let mut expected_write_check_final = K::ZERO; + for lane in lane_opens.iter() { + let read_eq_addr = eq_bits_prod(&lane.ra_bits, r_addr)?; + expected_read_check_final += chi_cycle_at_r_time * lane.has_read * (claimed_val - lane.rv) * read_eq_addr; + + let write_eq_addr = eq_bits_prod(&lane.wa_bits, r_addr)?; + expected_write_check_final += + chi_cycle_at_r_time * lane.has_write * (lane.wv - claimed_val - lane.inc) * write_eq_addr; + } + if expected_read_check_final != read_check_final { + return Err(PiCcsError::ProtocolError( + "twist/read_check terminal value mismatch".into(), + )); + } + + if expected_write_check_final != write_check_final { + return Err(PiCcsError::ProtocolError( + "twist/write_check terminal value mismatch".into(), + )); + } + + twist_time_openings.push(TwistTimeLaneOpenings { + lanes: lane_opens + .into_iter() + .map(|lane| TwistTimeLaneOpeningsLane { + wa_bits: lane.wa_bits, + has_write: lane.has_write, + inc_at_write_addr: lane.inc, + }) + .collect(), + }); + } + + // -------------------------------------------------------------------- + // Phase 2: Verify batched Twist val-eval sum-check, deriving shared r_val. 
+ // -------------------------------------------------------------------- + let mut r_val: Vec = Vec::new(); + let mut val_eval_finals: Vec = Vec::new(); + if !step.mem_insts.is_empty() { + let plan = crate::memory_sidecar::claim_plan::TwistValEvalClaimPlan::build(step.mem_insts.iter(), has_prev); + let claim_count = plan.claim_count; + + let mut per_claim_rounds: Vec>> = Vec::with_capacity(claim_count); + let mut per_claim_sums: Vec = Vec::with_capacity(claim_count); + let mut bind_claims: Vec<(u8, K)> = Vec::with_capacity(claim_count); + let mut claim_idx = 0usize; + + for (i_mem, _inst) in step.mem_insts.iter().enumerate() { + let twist_proof = match &proofs_mem[proof_mem_offset + i_mem] { + MemOrLutProof::Twist(proof) => proof, + _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), + }; + let val = twist_proof + .val_eval + .as_ref() + .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing val_eval proof".into()))?; + + per_claim_rounds.push(val.rounds_lt.clone()); + per_claim_sums.push(val.claimed_inc_sum_lt); + bind_claims.push((plan.bind_tags[claim_idx], val.claimed_inc_sum_lt)); + claim_idx += 1; + + per_claim_rounds.push(val.rounds_total.clone()); + per_claim_sums.push(val.claimed_inc_sum_total); + bind_claims.push((plan.bind_tags[claim_idx], val.claimed_inc_sum_total)); + claim_idx += 1; + + if has_prev { + let prev_total = val.claimed_prev_inc_sum_total.ok_or_else(|| { + PiCcsError::InvalidInput("Twist(Route A): missing claimed_prev_inc_sum_total".into()) + })?; + let prev_rounds = val + .rounds_prev_total + .clone() + .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing rounds_prev_total".into()))?; + per_claim_rounds.push(prev_rounds); + per_claim_sums.push(prev_total); + bind_claims.push((plan.bind_tags[claim_idx], prev_total)); + claim_idx += 1; + } else if val.claimed_prev_inc_sum_total.is_some() || val.rounds_prev_total.is_some() { + return Err(PiCcsError::InvalidInput( + "Twist(Route A): rollover fields 
present but prev_step is None".into(), + )); + } + } + + tr.append_message( + b"twist/val_eval/batch_start", + &(step.mem_insts.len() as u64).to_le_bytes(), + ); + tr.append_message(b"twist/val_eval/step_idx", &(step_idx as u64).to_le_bytes()); + bind_twist_val_eval_claim_sums(tr, &bind_claims); + + let (r_val_out, finals_out, ok) = verify_batched_sumcheck_rounds_ds( + tr, + b"twist/val_eval_batch", + step_idx, + &per_claim_rounds, + &per_claim_sums, + &plan.labels, + &plan.degree_bounds, + ); + if !ok { + return Err(PiCcsError::SumcheckError( + "twist val-eval batched sumcheck invalid".into(), + )); + } + if r_val_out.len() != r_time.len() { + return Err(PiCcsError::ProtocolError(format!( + "twist val-eval r_val.len()={}, expected ell_n={}", + r_val_out.len(), + r_time.len() + ))); + } + if finals_out.len() != claim_count { + return Err(PiCcsError::ProtocolError(format!( + "twist val-eval finals.len()={}, expected {}", + finals_out.len(), + claim_count + ))); + } + r_val = r_val_out; + val_eval_finals = finals_out; + + tr.append_message(b"twist/val_eval/batch_done", &[]); + } + + // Verify val-eval terminal identity against CPU ME openings at r_val. 
+ let lt = if step.mem_insts.is_empty() { + if !r_val.is_empty() { + return Err(PiCcsError::ProtocolError( + "twist val-eval produced r_val but no mem instances are present".into(), + )); + } + K::ZERO + } else { + if r_val.len() != r_time.len() { + return Err(PiCcsError::ProtocolError(format!( + "twist val-eval r_val.len()={}, expected ell_n={}", + r_val.len(), + r_time.len() + ))); + } + lt_eval(&r_val, r_time) + }; + + let (cpu_me_val_cur, cpu_me_val_prev, bus_y_base_val) = if step.mem_insts.is_empty() { + if !mem_proof.val_me_claims.is_empty() { + return Err(PiCcsError::InvalidInput( + "proof contains val-lane CPU ME claims with no Twist instances".into(), + )); + } + (None, None, 0usize) + } else { + let expected = 1usize + usize::from(has_prev); + if mem_proof.val_me_claims.len() != expected { + return Err(PiCcsError::InvalidInput(format!( + "shared bus expects {} CPU ME claim(s) at r_val, got {}", + expected, + mem_proof.val_me_claims.len() + ))); + } + + let cpu_me_cur = mem_proof + .val_me_claims + .get(0) + .ok_or_else(|| PiCcsError::ProtocolError("missing CPU ME claim at r_val".into()))?; + if cpu_me_cur.r.as_slice() != r_val { + return Err(PiCcsError::ProtocolError( + "CPU ME(val) r mismatch (expected r_val)".into(), + )); + } + if cpu_me_cur.c != step.mcs_inst.c { + return Err(PiCcsError::ProtocolError( + "CPU ME(val) commitment mismatch (current step)".into(), + )); + } + let cpu_me_prev = if has_prev { + let prev_inst = + prev_step.ok_or_else(|| PiCcsError::ProtocolError("prev_step missing with has_prev=true".into()))?; + let cpu_me_prev = mem_proof + .val_me_claims + .get(1) + .ok_or_else(|| PiCcsError::ProtocolError("missing prev CPU ME claim at r_val".into()))?; + if cpu_me_prev.r.as_slice() != r_val { + return Err(PiCcsError::ProtocolError( + "CPU ME(val/prev) r mismatch (expected r_val)".into(), + )); + } + if cpu_me_prev.c != prev_inst.mcs_inst.c { + return Err(PiCcsError::ProtocolError("CPU ME(val/prev) commitment mismatch".into())); + } + 
Some(cpu_me_prev) + } else { + None + }; + + let bus_y_base_val = cpu_me_cur + .y_scalars + .len() + .checked_sub(cpu_bus.bus_cols) + .ok_or_else(|| PiCcsError::InvalidInput("CPU y_scalars too short for bus openings".into()))?; + + (Some(cpu_me_cur), cpu_me_prev, bus_y_base_val) + }; + + for (i_mem, inst) in step.mem_insts.iter().enumerate() { + let twist_proof = match &proofs_mem[proof_mem_offset + i_mem] { + MemOrLutProof::Twist(proof) => proof, + _ => return Err(PiCcsError::InvalidInput("expected Twist proof".into())), + }; + let val_eval = twist_proof + .val_eval + .as_ref() + .ok_or_else(|| PiCcsError::InvalidInput("Twist(Route A): missing val_eval proof".into()))?; + let layout = inst.twist_layout(); + let ell_addr = layout + .lanes + .get(0) + .ok_or_else(|| PiCcsError::InvalidInput("TwistWitnessLayout has no lanes".into()))? + .ell_addr; + + let cpu_me_cur = + cpu_me_val_cur.ok_or_else(|| PiCcsError::ProtocolError("missing CPU ME claim at r_val".into()))?; + + let twist_inst_cols = cpu_bus + .twist_cols + .get(i_mem) + .ok_or_else(|| PiCcsError::InvalidInput("shared_cpu_bus layout mismatch (twist)".into()))?; + let expected_lanes = inst.lanes.max(1); + if twist_inst_cols.lanes.len() != expected_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at mem_idx={i_mem}: expected lanes={expected_lanes}, got {}", + twist_inst_cols.lanes.len() + ))); + } + + let r_addr = twist_pre + .get(i_mem) + .ok_or_else(|| PiCcsError::InvalidInput(format!("missing Twist pre-time data at index {}", i_mem)))? 
+ .r_addr + .as_slice(); + + let mut inc_at_r_addr_val = K::ZERO; + for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { + if twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at mem_idx={i_mem}, lane={lane_idx}: expected ell_addr={ell_addr}" + ))); + } + + let mut wa_bits_val_open = Vec::with_capacity(ell_addr); + for col_id in twist_cols.wa_bits.clone() { + wa_bits_val_open.push( + cpu_me_cur + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_val, col_id)) + .copied() + .ok_or_else(|| { + PiCcsError::ProtocolError("CPU y_scalars missing wa_bits(val) opening".into()) + })?, + ); + } + let has_write_val_open = cpu_me_cur + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.has_write)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing has_write(val) opening".into()))?; + let inc_at_write_addr_val_open = cpu_me_cur + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.inc)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing inc(val) opening".into()))?; + + let eq_wa_val = eq_bits_prod(&wa_bits_val_open, r_addr)?; + inc_at_r_addr_val += has_write_val_open * inc_at_write_addr_val_open * eq_wa_val; + } + + let expected_lt_final = inc_at_r_addr_val * lt; + let claims_per_mem = if has_prev { 3 } else { 2 }; + let base = claims_per_mem * i_mem; + if expected_lt_final != val_eval_finals[base] { + return Err(PiCcsError::ProtocolError( + "twist/val_eval_lt terminal value mismatch".into(), + )); + } + let expected_total_final = inc_at_r_addr_val; + if expected_total_final != val_eval_finals[base + 1] { + return Err(PiCcsError::ProtocolError( + "twist/val_eval_total terminal value mismatch".into(), + )); + } + + if has_prev { + let prev = + prev_step.ok_or_else(|| PiCcsError::ProtocolError("prev_step missing with has_prev=true".into()))?; + let prev_inst = prev + .mem_insts + 
.get(i_mem) + .ok_or_else(|| PiCcsError::ProtocolError("missing prev mem instance".into()))?; + let cpu_me_prev = cpu_me_val_prev + .ok_or_else(|| PiCcsError::ProtocolError("missing prev CPU ME claim at r_val".into()))?; + + // Terminal check for prev-total: uses previous-step openings at current r_val. + let mut inc_at_r_addr_prev = K::ZERO; + for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { + if twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at mem_idx={i_mem}, lane={lane_idx}: expected ell_addr={ell_addr}" + ))); + } + + let mut wa_bits_prev_open = Vec::with_capacity(ell_addr); + for col_id in twist_cols.wa_bits.clone() { + wa_bits_prev_open.push( + cpu_me_prev + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_val, col_id)) + .copied() + .ok_or_else(|| { + PiCcsError::ProtocolError("CPU y_scalars missing wa_bits(prev) opening".into()) + })?, + ); + } + let has_write_prev_open = cpu_me_prev + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.has_write)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing has_write(prev) opening".into()))?; + let inc_prev_open = cpu_me_prev + .y_scalars + .get(cpu_bus.y_scalar_index(bus_y_base_val, twist_cols.inc)) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("CPU y_scalars missing inc(prev) opening".into()))?; + + let eq_wa_prev = eq_bits_prod(&wa_bits_prev_open, r_addr)?; + inc_at_r_addr_prev += has_write_prev_open * inc_prev_open * eq_wa_prev; + } + if inc_at_r_addr_prev != val_eval_finals[base + 2] { + return Err(PiCcsError::ProtocolError( + "twist/rollover_prev_total terminal value mismatch".into(), + )); + } + + // Enforce rollover equation: Init_i(r_addr) == Init_{i-1}(r_addr) + PrevTotal(i). 
+ let claimed_prev_total = val_eval + .claimed_prev_inc_sum_total + .ok_or_else(|| PiCcsError::ProtocolError("twist rollover missing claimed_prev_inc_sum_total".into()))?; + let init_prev_at_r_addr = eval_init_at_r_addr(&prev_inst.init, prev_inst.k, r_addr)?; + let init_cur_at_r_addr = eval_init_at_r_addr(&inst.init, inst.k, r_addr)?; + if init_cur_at_r_addr != init_prev_at_r_addr + claimed_prev_total { + return Err(PiCcsError::ProtocolError("twist rollover init check failed".into())); + } + } + } + + verify_route_a_wb_wp_terminals( + core_t, + step, + r_time, + r_cycle, + batched_final_values, + &claim_plan, + mem_proof, + )?; + verify_route_a_decode_terminals( + core_t, + step, + r_time, + r_cycle, + batched_final_values, + &claim_plan, + mem_proof, + )?; + verify_route_a_width_terminals( + core_t, + step, + r_time, + r_cycle, + batched_final_values, + &claim_plan, + mem_proof, + )?; + verify_route_a_control_terminals( + core_t, + step, + r_time, + r_cycle, + batched_final_values, + &claim_plan, + mem_proof, + )?; + + Ok(RouteAMemoryVerifyOutput { + claim_idx_end: claim_plan.claim_idx_end, + twist_time_openings, + }) +} diff --git a/crates/neo-fold/src/memory_sidecar/memory/sparse_oracles_and_twist_pre.rs b/crates/neo-fold/src/memory_sidecar/memory/sparse_oracles_and_twist_pre.rs new file mode 100644 index 00000000..afb8523a --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/sparse_oracles_and_twist_pre.rs @@ -0,0 +1,657 @@ +use super::*; + +pub(crate) fn sparse_trace_col_from_values( + m_in: usize, + ell_n: usize, + values: &[K], +) -> Result, PiCcsError> { + let pow2_cycle = 1usize + .checked_shl(ell_n as u32) + .ok_or_else(|| PiCcsError::InvalidInput("WB/WP: 2^ell_n overflow".into()))?; + let t_len = values.len(); + if m_in + .checked_add(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("WB/WP: m_in + t_len overflow".into()))? 
+ > pow2_cycle + { + return Err(PiCcsError::InvalidInput(format!( + "WB/WP: trace rows out of range (m_in={m_in}, t_len={t_len}, 2^ell_n={pow2_cycle})" + ))); + } + let mut entries = Vec::new(); + for (j, &v) in values.iter().enumerate() { + if v != K::ZERO { + entries.push((m_in + j, v)); + } + } + Ok(SparseIdxVec::from_entries(pow2_cycle, entries)) +} + +#[inline] +pub(crate) fn decode_k_to_u32(v: K, ctx: &str) -> Result { + let coeffs = v.as_coeffs(); + if coeffs.iter().skip(1).any(|&c| c != F::ZERO) { + return Err(PiCcsError::ProtocolError(format!( + "{ctx}: expected base-field value while decoding shared decode columns" + ))); + } + let lo = coeffs + .first() + .copied() + .ok_or_else(|| PiCcsError::ProtocolError(format!("{ctx}: missing base coefficient")))? + .as_canonical_u64(); + if lo > u32::MAX as u64 { + return Err(PiCcsError::ProtocolError(format!( + "{ctx}: value {lo} exceeds u32 range while decoding shared decode columns" + ))); + } + Ok(lo as u32) +} + +pub(crate) fn resolve_shared_decode_lookup_lut_indices( + step: &StepWitnessBundle, + decode_layout: &Rv32DecodeSidecarLayout, +) -> Result<(Vec, Vec), PiCcsError> { + let decode_open_cols = rv32_decode_lookup_backed_cols(decode_layout); + let mut decode_lut_indices = Vec::with_capacity(decode_open_cols.len()); + for &col_id in decode_open_cols.iter() { + let table_id = rv32_decode_lookup_table_id_for_col(col_id); + let idx = step + .lut_instances + .iter() + .position(|(inst, _)| inst.table_id == table_id) + .ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "W2(shared): missing decode lookup table_id={table_id} for col_id={col_id}" + )) + })?; + decode_lut_indices.push(idx); + } + + Ok((decode_open_cols, decode_lut_indices)) +} + +pub(crate) struct WeightedMaskOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + active: SparseIdxVec, + cols: Vec>, + weights: Vec, +} + +impl WeightedMaskOracleSparseTime { + pub(crate) fn new(active: SparseIdxVec, cols: Vec>, weights: Vec, 
r_cycle: &[K]) -> Self { + debug_assert_eq!(cols.len(), weights.len()); + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + active, + cols, + weights, + } + } +} + +impl RoundOracle for WeightedMaskOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.cols.is_empty() { + return vec![K::ZERO; points.len()]; + } + + if self.active.len() == 1 { + let gate = K::ONE - self.active.singleton_value(); + let mut acc = K::ZERO; + for (col, w) in self.cols.iter().zip(self.weights.iter()) { + acc += *w * col.singleton_value(); + } + return vec![self.prefix_eq * gate * acc; points.len()]; + } + + let mut pairs = gather_pairs_from_sparse(self.active.entries()); + for col in self.cols.iter() { + pairs.extend(gather_pairs_from_sparse(col.entries())); + } + pairs.sort_unstable(); + pairs.dedup(); + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = K::ONE - self.active.get(child0); + let gate1 = K::ONE - self.active.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let mut sum_x = K::ZERO; + for (col, w) in self.cols.iter().zip(self.weights.iter()) { + let c0 = col.get(child0); + let c1 = col.get(child1); + if c0 == K::ZERO && c1 == K::ZERO { + continue; + } + sum_x += *w * interp(c0, c1, x); + } + ys[i] += chi_x * gate_x * sum_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + 3 + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single_k(r, self.r_cycle[self.bit_idx]); + 
self.active.fold_round_in_place(r); + for col in self.cols.iter_mut() { + col.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +pub(crate) struct FormulaOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + cols: Vec>, + degree_bound: usize, + eval_fn: Box K>, +} + +impl FormulaOracleSparseTime { + pub(crate) fn new( + cols: Vec>, + degree_bound: usize, + r_cycle: &[K], + eval_fn: Box K>, + ) -> Self { + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + cols, + degree_bound, + eval_fn, + } + } +} + +impl RoundOracle for FormulaOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.cols.is_empty() { + return vec![K::ZERO; points.len()]; + } + + let mut pairs = Vec::new(); + for col in self.cols.iter() { + pairs.extend(gather_pairs_from_sparse(col.entries())); + } + pairs.sort_unstable(); + pairs.dedup(); + + let mut ys = vec![K::ZERO; points.len()]; + let mut vals = vec![K::ZERO; self.cols.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + for (j, col) in self.cols.iter().enumerate() { + vals[j] = interp(col.get(child0), col.get(child1), x); + } + let f_x = (self.eval_fn)(&vals); + if f_x == K::ZERO { + continue; + } + ys[i] += chi_x * f_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single_k(r, self.r_cycle[self.bit_idx]); + for col in self.cols.iter_mut() { + col.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +#[inline] +pub(crate) fn unpack_interleaved_halves_lsb(addr_bits: &[K]) -> Result<(K, K), PiCcsError> { + 
if !addr_bits.len().is_multiple_of(2) { + return Err(PiCcsError::InvalidInput(format!( + "shout linkage expects even ell_addr, got {}", + addr_bits.len() + ))); + } + let half_len = addr_bits.len() / 2; + let two = K::from(F::from_u64(2)); + let mut pow = K::ONE; + let mut lhs = K::ZERO; + let mut rhs = K::ZERO; + for k in 0..half_len { + lhs += pow * addr_bits[2 * k]; + rhs += pow * addr_bits[2 * k + 1]; + pow *= two; + } + Ok((lhs, rhs)) +} + +pub(crate) fn extract_trace_cpu_link_openings( + m: usize, + core_t: usize, + y_prefix_cols: usize, + step: &StepInstanceBundle, + ccs_out0: &MeInstance, +) -> Result, PiCcsError> { + if step.mem_insts.is_empty() && step.lut_insts.is_empty() { + return Ok(None); + } + + // RV32 trace linkage: the prover appends time-combined openings for selected CPU trace columns + // to the CCS ME output at r_time. We use those to bind Twist instances (PROG/REG/RAM) to the + // same trace, without embedding a shared CPU bus tail. + let trace = Rv32TraceLayout::new(); + let trace_cols_to_open: Vec = vec![ + trace.active, + trace.cycle, + trace.pc_before, + trace.instr_word, + trace.rs1_addr, + trace.rs1_val, + trace.rs2_addr, + trace.rs2_val, + trace.rd_addr, + trace.rd_val, + trace.ram_addr, + trace.ram_rv, + trace.ram_wv, + trace.shout_has_lookup, + trace.shout_val, + trace.shout_lhs, + trace.shout_rhs, + ]; + + let m_in = step.mcs_inst.m_in; + let t_len = step + .mem_insts + .first() + .map(|inst| inst.steps) + .or_else(|| { + // Shout event-table instances may have `steps != t_len`; prefer a non-event-table + // instance if present, otherwise fall back to inferring from the trace layout. + step.lut_insts + .iter() + .find(|inst| !matches!(inst.table_spec, Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}))) + .map(|inst| inst.steps) + }) + .or_else(|| { + // Trace CCS layout inference: z = [x (m_in) | trace_cols * t_len] + let w = m.checked_sub(m_in)?; + if trace.cols == 0 || w % trace.cols != 0 { + return None; + } + Some(w / trace.cols) + }) + .ok_or_else(|| PiCcsError::InvalidInput("missing mem/lut instances".into()))?; + if t_len == 0 { + return Err(PiCcsError::InvalidInput("trace linkage requires steps>=1".into())); + } + for (i, inst) in step.mem_insts.iter().enumerate() { + if inst.steps != t_len { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage requires stable steps across mem instances (mem_idx={i} has steps={}, expected {t_len})", + inst.steps + ))); + } + } + let trace_len = trace + .cols + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace cols * t_len overflow".into()))?; + let expected_m = m_in + .checked_add(trace_len) + .ok_or_else(|| PiCcsError::InvalidInput("m_in + trace_len overflow".into()))?; + if m < expected_m { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage expects m >= m_in + trace.cols*t_len (m={}; min_m={expected_m} for t_len={t_len}, trace_cols={})", + m, trace.cols + ))); + } + let expected_y_len = core_t + .checked_add(y_prefix_cols) + .and_then(|v| v.checked_add(trace_cols_to_open.len())) + .ok_or_else(|| PiCcsError::InvalidInput("core_t + y_prefix_cols + trace_openings overflow".into()))?; + if ccs_out0.y_scalars.len() != expected_y_len { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage expects CPU ME output to contain exactly core_t + y_prefix_cols + trace_openings y_scalars (have {}, expected {expected_y_len})", + ccs_out0.y_scalars.len(), + ))); + } + let cpu_open = |idx: usize| -> Result { + ccs_out0 + .y_scalars + .get(core_t + y_prefix_cols + idx) + .copied() + .ok_or_else(|| PiCcsError::ProtocolError("missing CPU trace linkage opening".into())) + }; + + Ok(Some(TraceCpuLinkOpenings { + shout_has_lookup: cpu_open(13)?, + shout_val: cpu_open(14)?, + shout_lhs: 
cpu_open(15)?, + shout_rhs: cpu_open(16)?, + })) +} + +pub(crate) fn expected_trace_shout_table_id_from_openings( + core_t: usize, + step: &StepInstanceBundle, + mem_proof: &MemSidecarProof, + r_time: &[K], +) -> Result { + if !decode_stage_required_for_step_instance(step) { + return Ok(K::ZERO); + } + + if mem_proof.wp_me_claims.len() != 1 { + return Err(PiCcsError::ProtocolError( + "decode-linked Shout table_id check requires one WP ME claim".into(), + )); + } + let wp_me = &mem_proof.wp_me_claims[0]; + if wp_me.r.as_slice() != r_time { + return Err(PiCcsError::ProtocolError( + "decode-linked Shout table_id check: WP ME r mismatch".into(), + )); + } + if wp_me.c != step.mcs_inst.c { + return Err(PiCcsError::ProtocolError( + "decode-linked Shout table_id check: WP ME commitment mismatch".into(), + )); + } + if wp_me.m_in != step.mcs_inst.m_in { + return Err(PiCcsError::ProtocolError( + "decode-linked Shout table_id check: WP ME m_in mismatch".into(), + )); + } + + let trace = Rv32TraceLayout::new(); + let decode_layout = Rv32DecodeSidecarLayout::new(); + let wp_cols = rv32_trace_wp_opening_columns(&trace); + let control_extra_cols = if control_stage_required_for_step_instance(step) { + rv32_trace_control_extra_opening_columns(&trace) + } else { + Vec::new() + }; + let decode_open_cols = rv32_decode_lookup_backed_cols(&decode_layout); + + let decode_open_start = core_t + .checked_add(wp_cols.len()) + .and_then(|v| v.checked_add(control_extra_cols.len())) + .ok_or_else(|| { + PiCcsError::InvalidInput("decode-linked Shout table_id check: decode_open_start overflow".into()) + })?; + let decode_open_end = decode_open_start + .checked_add(decode_open_cols.len()) + .ok_or_else(|| { + PiCcsError::InvalidInput("decode-linked Shout table_id check: decode_open_end overflow".into()) + })?; + if wp_me.y_scalars.len() < decode_open_end { + return Err(PiCcsError::ProtocolError(format!( + "decode-linked Shout table_id check: missing decode openings (got {}, need at least 
{decode_open_end})", + wp_me.y_scalars.len() + ))); + } + + let decode_open = &wp_me.y_scalars[decode_open_start..decode_open_end]; + let decode_open_col = |col_id: usize| -> Result { + let idx = decode_open_cols + .iter() + .position(|&c| c == col_id) + .ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "decode-linked Shout table_id check: missing decode opening col {col_id}" + )) + })?; + Ok(decode_open[idx]) + }; + + Ok(decode_open_col(decode_layout.shout_table_id)?) +} + +pub(crate) fn prove_twist_addr_pre_time( + tr: &mut Poseidon2Transcript, + params: &NeoParams, + step: &StepWitnessBundle, + cpu_bus: &BusLayout, + ell_n: usize, + r_cycle: &[K], +) -> Result, PiCcsError> { + if step.mem_instances.is_empty() { + return Ok(Vec::new()); + } + let mut out = Vec::with_capacity(step.mem_instances.len()); + + let cpu_z_k = crate::memory_sidecar::cpu_bus::decode_cpu_z_to_k(params, &step.mcs.1.Z); + if cpu_bus.shout_cols.len() != step.lut_instances.len() || cpu_bus.twist_cols.len() != step.mem_instances.len() { + return Err(PiCcsError::InvalidInput( + "shared_cpu_bus layout mismatch for step (instance counts)".into(), + )); + } + + for (idx, (mem_inst, _mem_wit)) in step.mem_instances.iter().enumerate() { + neo_memory::addr::validate_twist_bit_addressing(mem_inst)?; + let pow2_cycle = 1usize << ell_n; + if mem_inst.steps > pow2_cycle { + return Err(PiCcsError::InvalidInput(format!( + "Twist(Route A): steps={} exceeds 2^ell_cycle={pow2_cycle}", + mem_inst.steps + ))); + } + + let bus = cpu_bus.clone(); + let z = cpu_z_k.clone(); + + let ell_addr = mem_inst.d * mem_inst.ell; + let expected_lanes = mem_inst.lanes.max(1); + let twist_inst_cols = bus.twist_cols.get(idx).ok_or_else(|| { + PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch: missing twist_cols for mem_idx={idx}" + )) + })?; + if twist_inst_cols.lanes.len() != expected_lanes { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at mem_idx={idx}: expected 
lanes={expected_lanes}, got {}", + twist_inst_cols.lanes.len() + ))); + } + + let mut lanes: Vec = Vec::with_capacity(twist_inst_cols.lanes.len()); + for (lane_idx, twist_cols) in twist_inst_cols.lanes.iter().enumerate() { + if twist_cols.ra_bits.end - twist_cols.ra_bits.start != ell_addr + || twist_cols.wa_bits.end - twist_cols.wa_bits.start != ell_addr + { + return Err(PiCcsError::InvalidInput(format!( + "shared_cpu_bus layout mismatch at mem_idx={idx}, lane={lane_idx}: expected ell_addr={ell_addr}" + ))); + } + + let mut ra_bits = Vec::with_capacity(ell_addr); + for col_id in twist_cols.ra_bits.clone() { + ra_bits.push(crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( + &z, + &bus, + col_id, + mem_inst.steps, + pow2_cycle, + )?); + } + + let mut wa_bits = Vec::with_capacity(ell_addr); + for col_id in twist_cols.wa_bits.clone() { + wa_bits.push(crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( + &z, + &bus, + col_id, + mem_inst.steps, + pow2_cycle, + )?); + } + + let has_read = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( + &z, + &bus, + twist_cols.has_read, + mem_inst.steps, + pow2_cycle, + )?; + let has_write = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( + &z, + &bus, + twist_cols.has_write, + mem_inst.steps, + pow2_cycle, + )?; + let wv = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( + &z, + &bus, + twist_cols.wv, + mem_inst.steps, + pow2_cycle, + )?; + let rv = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( + &z, + &bus, + twist_cols.rv, + mem_inst.steps, + pow2_cycle, + )?; + let inc_at_write_addr = crate::memory_sidecar::cpu_bus::build_time_sparse_from_bus_col( + &z, + &bus, + twist_cols.inc, + mem_inst.steps, + pow2_cycle, + )?; + + lanes.push(TwistLaneSparseCols { + ra_bits, + wa_bits, + has_read, + has_write, + wv, + rv, + inc_at_write_addr, + }); + } + + let decoded = TwistDecodedColsSparse { lanes }; + + let init_sparse: Vec<(usize, K)> = match 
&mem_inst.init { + MemInit::Zero => Vec::new(), + MemInit::Sparse(pairs) => pairs + .iter() + .map(|(addr, val)| { + let addr_usize = usize::try_from(*addr).map_err(|_| { + PiCcsError::InvalidInput(format!("Twist: init address doesn't fit usize: addr={addr}")) + })?; + if addr_usize >= mem_inst.k { + return Err(PiCcsError::InvalidInput(format!( + "Twist: init address out of range: addr={addr} >= k={}", + mem_inst.k + ))); + } + Ok((addr_usize, (*val).into())) + }) + .collect::>()?, + }; + + let mut read_addr_oracle = + TwistReadCheckAddrOracleSparseTimeMultiLane::new(init_sparse.clone(), r_cycle, &decoded.lanes); + let mut write_addr_oracle = + TwistWriteCheckAddrOracleSparseTimeMultiLane::new(init_sparse, r_cycle, &decoded.lanes); + + let labels: [&[u8]; 2] = [b"twist/read_addr_pre".as_slice(), b"twist/write_addr_pre".as_slice()]; + let claimed_sums = vec![K::ZERO, K::ZERO]; + tr.append_message(b"twist/addr_pre_time/claim_idx", &(idx as u64).to_le_bytes()); + bind_batched_claim_sums(tr, b"twist/addr_pre_time/claimed_sums", &claimed_sums, &labels); + + let mut claims = [ + BatchedClaim { + oracle: &mut read_addr_oracle, + claimed_sum: K::ZERO, + label: labels[0], + }, + BatchedClaim { + oracle: &mut write_addr_oracle, + claimed_sum: K::ZERO, + label: labels[1], + }, + ]; + + let (r_addr, per_claim_results) = run_batched_sumcheck_prover_ds(tr, b"twist/addr_pre_time", idx, &mut claims)?; + if per_claim_results.len() != 2 { + return Err(PiCcsError::ProtocolError(format!( + "twist addr-pre per-claim results len()={}, expected 2", + per_claim_results.len() + ))); + } + + out.push(TwistAddrPreProverData { + addr_pre: BatchedAddrProof { + claimed_sums, + round_polys: vec![ + per_claim_results[0].round_polys.clone(), + per_claim_results[1].round_polys.clone(), + ], + r_addr: r_addr.clone(), + }, + decoded, + read_check_claim_sum: per_claim_results[0].final_value, + write_check_claim_sum: per_claim_results[1].final_value, + }); + } + + Ok(out) +} diff --git 
a/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs new file mode 100644 index 00000000..1c1cef55 --- /dev/null +++ b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs @@ -0,0 +1,1494 @@ +use super::*; + +// ============================================================================ +// Transcript binding +// ============================================================================ + +pub(crate) fn bind_shout_table_spec(tr: &mut Poseidon2Transcript, spec: &Option) { + let Some(spec) = spec else { + return; + }; + + tr.append_message(b"shout/table_spec/tag", &[1u8]); + match spec { + LutTableSpec::RiscvOpcode { opcode, xlen } => { + let opcode_id = neo_memory::riscv::lookups::RiscvShoutTables::new(*xlen) + .opcode_to_id(*opcode) + .0 as u64; + + tr.append_message(b"shout/table_spec/riscv/tag", &[1u8]); + tr.append_message(b"shout/table_spec/riscv/opcode_id", &opcode_id.to_le_bytes()); + tr.append_message(b"shout/table_spec/riscv/xlen", &(*xlen as u64).to_le_bytes()); + } + LutTableSpec::RiscvOpcodePacked { opcode, xlen } => { + let opcode_id = neo_memory::riscv::lookups::RiscvShoutTables::new(*xlen) + .opcode_to_id(*opcode) + .0 as u64; + + tr.append_message(b"shout/table_spec/riscv_packed/tag", &[1u8]); + tr.append_message(b"shout/table_spec/riscv_packed/opcode_id", &opcode_id.to_le_bytes()); + tr.append_message(b"shout/table_spec/riscv_packed/xlen", &(*xlen as u64).to_le_bytes()); + } + LutTableSpec::RiscvOpcodeEventTablePacked { + opcode, + xlen, + time_bits, + } => { + let opcode_id = neo_memory::riscv::lookups::RiscvShoutTables::new(*xlen) + .opcode_to_id(*opcode) + .0 as u64; + + tr.append_message(b"shout/table_spec/riscv_event_table_packed/tag", &[1u8]); + tr.append_message( + b"shout/table_spec/riscv_event_table_packed/opcode_id", + &opcode_id.to_le_bytes(), + ); + tr.append_message( + b"shout/table_spec/riscv_event_table_packed/xlen", + &(*xlen as 
u64).to_le_bytes(), + ); + tr.append_message( + b"shout/table_spec/riscv_event_table_packed/time_bits", + &(*time_bits as u64).to_le_bytes(), + ); + } + LutTableSpec::IdentityU32 => { + tr.append_message(b"shout/table_spec/identity_u32/tag", &[1u8]); + } + } +} + +pub(crate) fn absorb_step_memory_impl<'a, LI, MI>(tr: &mut Poseidon2Transcript, mut lut_insts: LI, mut mem_insts: MI) +where + LI: ExactSizeIterator>, + MI: ExactSizeIterator>, +{ + tr.append_message(b"step/absorb_memory_start", &[]); + tr.append_message(b"step/lut_count", &(lut_insts.len() as u64).to_le_bytes()); + for (i, inst) in lut_insts.by_ref().enumerate() { + // Bind public LUT parameters before any challenges. + tr.append_message(b"step/lut_idx", &(i as u64).to_le_bytes()); + tr.append_message(b"shout/table_id", &(inst.table_id as u64).to_le_bytes()); + tr.append_message(b"shout/k", &(inst.k as u64).to_le_bytes()); + tr.append_message(b"shout/d", &(inst.d as u64).to_le_bytes()); + tr.append_message(b"shout/n_side", &(inst.n_side as u64).to_le_bytes()); + tr.append_message(b"shout/steps", &(inst.steps as u64).to_le_bytes()); + tr.append_message(b"shout/ell", &(inst.ell as u64).to_le_bytes()); + tr.append_message(b"shout/lanes", &(inst.lanes.max(1) as u64).to_le_bytes()); + bind_shout_table_spec(tr, &inst.table_spec); + let table_digest = digest_fields(b"shout/table", &inst.table); + tr.append_message(b"shout/table_digest", &table_digest); + + // Bind commitments so Route-A challenges (r_cycle, addr/time points) are sampled after them. + tr.append_message(b"shout/comms_len", &(inst.comms.len() as u64).to_le_bytes()); + for (j, comm) in inst.comms.iter().enumerate() { + tr.append_message(b"shout/comm_idx", &(j as u64).to_le_bytes()); + tr.append_fields(b"shout/comm_data", &comm.data); + } + } + tr.append_message(b"step/mem_count", &(mem_insts.len() as u64).to_le_bytes()); + for (i, inst) in mem_insts.by_ref().enumerate() { + // Bind public memory parameters before any challenges. 
+ tr.append_message(b"step/mem_idx", &(i as u64).to_le_bytes()); + tr.append_message(b"twist/mem_id", &(inst.mem_id as u64).to_le_bytes()); + tr.append_message(b"twist/k", &(inst.k as u64).to_le_bytes()); + tr.append_message(b"twist/d", &(inst.d as u64).to_le_bytes()); + tr.append_message(b"twist/n_side", &(inst.n_side as u64).to_le_bytes()); + tr.append_message(b"twist/steps", &(inst.steps as u64).to_le_bytes()); + tr.append_message(b"twist/ell", &(inst.ell as u64).to_le_bytes()); + tr.append_message(b"twist/lanes", &(inst.lanes.max(1) as u64).to_le_bytes()); + let init_digest = match &inst.init { + MemInit::Zero => digest_fields(b"twist/init/zero", &[]), + MemInit::Sparse(pairs) => { + let mut fs = Vec::with_capacity(2 * pairs.len()); + for (addr, val) in pairs.iter() { + fs.push(F::from_u64(*addr)); + fs.push(*val); + } + digest_fields(b"twist/init/sparse", &fs) + } + }; + tr.append_message(b"twist/init_digest", &init_digest); + + // Bind commitments so Route-A challenges (r_cycle, addr/time points) are sampled after them. 
+ tr.append_message(b"twist/comms_len", &(inst.comms.len() as u64).to_le_bytes()); + for (j, comm) in inst.comms.iter().enumerate() { + tr.append_message(b"twist/comm_idx", &(j as u64).to_le_bytes()); + tr.append_fields(b"twist/comm_data", &comm.data); + } + } + tr.append_message(b"step/absorb_memory_done", &[]); +} + +pub fn absorb_step_memory(tr: &mut Poseidon2Transcript, step: &StepInstanceBundle) { + absorb_step_memory_impl(tr, step.lut_insts.iter(), step.mem_insts.iter()); +} + +pub(crate) fn absorb_step_memory_witness(tr: &mut Poseidon2Transcript, step: &StepWitnessBundle) { + absorb_step_memory_impl( + tr, + step.lut_instances.iter().map(|(inst, _)| inst), + step.mem_instances.iter().map(|(inst, _)| inst), + ); +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum Rv32PackedShoutOp { + And, + Andn, + Add, + Or, + Sub, + Xor, + Eq, + Neq, + Slt, + Sll, + Srl, + Sra, + Sltu, + Mul, + Mulh, + Mulhu, + Mulhsu, + Div, + Divu, + Rem, + Remu, +} + +pub(crate) fn rv32_packed_shout_layout( + spec: &Option, +) -> Result, PiCcsError> { + let (opcode, xlen, time_bits) = match spec { + Some(LutTableSpec::RiscvOpcodePacked { opcode, xlen }) => (*opcode, *xlen, 0usize), + Some(LutTableSpec::RiscvOpcodeEventTablePacked { + opcode, + xlen, + time_bits, + }) => (*opcode, *xlen, *time_bits), + _ => return Ok(None), + }; + + if xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "packed RISC-V Shout is only supported for RV32 (xlen=32) in Route A (got xlen={xlen})" + ))); + } + if time_bits == 0 { + // `RiscvOpcodePacked` uses `time_bits=0` (no prefix). Event-table packed must be >= 1. + if matches!(spec, Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. 
})) { + return Err(PiCcsError::InvalidInput( + "RiscvOpcodeEventTablePacked requires time_bits >= 1".into(), + )); + } + } + + let op = match opcode { + neo_memory::riscv::lookups::RiscvOpcode::And => Rv32PackedShoutOp::And, + neo_memory::riscv::lookups::RiscvOpcode::Andn => Rv32PackedShoutOp::Andn, + neo_memory::riscv::lookups::RiscvOpcode::Add => Rv32PackedShoutOp::Add, + neo_memory::riscv::lookups::RiscvOpcode::Or => Rv32PackedShoutOp::Or, + neo_memory::riscv::lookups::RiscvOpcode::Sub => Rv32PackedShoutOp::Sub, + neo_memory::riscv::lookups::RiscvOpcode::Xor => Rv32PackedShoutOp::Xor, + neo_memory::riscv::lookups::RiscvOpcode::Eq => Rv32PackedShoutOp::Eq, + neo_memory::riscv::lookups::RiscvOpcode::Neq => Rv32PackedShoutOp::Neq, + neo_memory::riscv::lookups::RiscvOpcode::Slt => Rv32PackedShoutOp::Slt, + neo_memory::riscv::lookups::RiscvOpcode::Sll => Rv32PackedShoutOp::Sll, + neo_memory::riscv::lookups::RiscvOpcode::Srl => Rv32PackedShoutOp::Srl, + neo_memory::riscv::lookups::RiscvOpcode::Sra => Rv32PackedShoutOp::Sra, + neo_memory::riscv::lookups::RiscvOpcode::Sltu => Rv32PackedShoutOp::Sltu, + neo_memory::riscv::lookups::RiscvOpcode::Mul => Rv32PackedShoutOp::Mul, + neo_memory::riscv::lookups::RiscvOpcode::Mulh => Rv32PackedShoutOp::Mulh, + neo_memory::riscv::lookups::RiscvOpcode::Mulhu => Rv32PackedShoutOp::Mulhu, + neo_memory::riscv::lookups::RiscvOpcode::Mulhsu => Rv32PackedShoutOp::Mulhsu, + neo_memory::riscv::lookups::RiscvOpcode::Div => Rv32PackedShoutOp::Div, + neo_memory::riscv::lookups::RiscvOpcode::Divu => Rv32PackedShoutOp::Divu, + neo_memory::riscv::lookups::RiscvOpcode::Rem => Rv32PackedShoutOp::Rem, + neo_memory::riscv::lookups::RiscvOpcode::Remu => Rv32PackedShoutOp::Remu, + _ => { + return Err(PiCcsError::InvalidInput(format!( + "packed RISC-V Shout is only supported for selected RV32 ops in Route A (got opcode={opcode:?})" + ))); + } + }; + + Ok(Some((op, time_bits))) +} + +pub(crate) fn rv32_shout_table_id_from_spec(spec: &Option) -> Result { 
+ let (opcode, xlen) = match spec { + Some(LutTableSpec::RiscvOpcode { opcode, xlen }) => (*opcode, *xlen), + Some(LutTableSpec::RiscvOpcodePacked { opcode, xlen }) => (*opcode, *xlen), + Some(LutTableSpec::RiscvOpcodeEventTablePacked { opcode, xlen, .. }) => (*opcode, *xlen), + Some(LutTableSpec::IdentityU32) => { + return Err(PiCcsError::InvalidInput( + "trace linkage expects RISC-V shout table specs (IdentityU32 is unsupported)".into(), + )); + } + None => { + return Err(PiCcsError::InvalidInput( + "trace linkage requires LutTableSpec on Shout instances".into(), + )); + } + }; + + if xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage expects RV32 shout specs (got xlen={xlen})" + ))); + } + Ok(neo_memory::riscv::lookups::RiscvShoutTables::new(xlen) + .opcode_to_id(opcode) + .0) +} + +pub(crate) fn rv32_trace_link_table_id_from_spec(spec: &Option) -> Result, PiCcsError> { + match spec { + Some(LutTableSpec::RiscvOpcode { .. }) + | Some(LutTableSpec::RiscvOpcodePacked { .. }) + | Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}) => Ok(Some(rv32_shout_table_id_from_spec(spec)?)), + Some(LutTableSpec::IdentityU32) | None => Ok(None), + } +} + +// ============================================================================ +// Prover helpers +// ============================================================================ + +pub(crate) struct ShoutDecodedColsSparse { + pub lanes: Vec, +} + +pub(crate) struct ShoutLaneSparseCols { + pub addr_bits: Vec>, + pub has_lookup: SparseIdxVec, + pub val: SparseIdxVec, +} + +pub(crate) struct TwistDecodedColsSparse { + pub lanes: Vec, +} + +pub(crate) struct SumRoundOracle { + oracles: Vec>, + num_rounds: usize, + degree_bound: usize, +} + +impl SumRoundOracle { + pub(crate) fn new(oracles: Vec>) -> Self { + if oracles.is_empty() { + panic!("SumRoundOracle requires at least one oracle"); + } + + let num_rounds = oracles[0].num_rounds(); + let degree_bound = oracles[0].degree_bound(); + for (idx, o) in oracles.iter().enumerate().skip(1) { + if o.num_rounds() != num_rounds { + panic!( + "SumRoundOracle num_rounds mismatch at idx={idx} (got {}, expected {num_rounds})", + o.num_rounds() + ); + } + if o.degree_bound() != degree_bound { + panic!( + "SumRoundOracle degree_bound mismatch at idx={idx} (got {}, expected {degree_bound})", + o.degree_bound() + ); + } + } + + Self { + oracles, + num_rounds, + degree_bound, + } + } +} + +impl RoundOracle for SumRoundOracle { + fn evals_at(&mut self, points: &[K]) -> Vec { + let mut acc = vec![K::ZERO; points.len()]; + for o in self.oracles.iter_mut() { + let ys = o.evals_at(points); + if ys.len() != acc.len() { + panic!( + "SumRoundOracle eval length mismatch (got {}, expected {})", + ys.len(), + acc.len() + ); + } + for (a, y) in acc.iter_mut().zip(ys) { + *a += y; + } + } + acc + } + + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + for o in self.oracles.iter_mut() { + o.fold(r); + } + self.num_rounds = 
self.oracles[0].num_rounds(); + } +} + +#[inline] +pub(crate) fn interp(a0: K, a1: K, x: K) -> K { + a0 + (a1 - a0) * x +} + +pub(crate) fn log2_pow2(n: usize) -> usize { + if n == 0 { + return 0; + } + debug_assert!(n.is_power_of_two(), "expected power of two, got {n}"); + n.trailing_zeros() as usize +} + +pub(crate) fn gather_pairs_from_sparse(entries: &[(usize, K)]) -> Vec { + let mut out: Vec = Vec::with_capacity(entries.len()); + let mut prev: Option = None; + for &(idx, _v) in entries { + let p = idx >> 1; + if prev != Some(p) { + out.push(p); + prev = Some(p); + } + } + out +} + +/// Sparse time-domain oracle for event-table RV32 Shout hash linkage: +/// Σ_t has_lookup(t) · (1 + α·val(t) + β·lhs(t) + γ·rhs(t)) · Π_b eq(time_bit_b(t), r_addr_b) +/// +/// Intended usage: +/// - `time_bit_b(t)` encodes the original cycle index of event row `t` (little-endian). +/// - `r_addr` is set to `r_cycle` so the claim is an MLE evaluation over cycle indices. +pub(crate) struct ShoutEventTableHashOracleSparseTime { + degree_bound: usize, + r_addr: Vec, + + time_bits: Vec>, + has_lookup: SparseIdxVec, + val: SparseIdxVec, + lhs: SparseIdxVec, + rhs_terms: Vec<(SparseIdxVec, K)>, + + alpha: K, + beta: K, + gamma: K, +} + +impl ShoutEventTableHashOracleSparseTime { + pub(crate) fn new( + r_addr: &[K], + time_bits: Vec>, + has_lookup: SparseIdxVec, + val: SparseIdxVec, + lhs: SparseIdxVec, + rhs_terms: Vec<(SparseIdxVec, K)>, + alpha: K, + beta: K, + gamma: K, + ) -> (Self, K) { + let ell_n = log2_pow2(has_lookup.len()); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + for (i, col) in time_bits.iter().enumerate() { + debug_assert_eq!(col.len(), 1usize << ell_n, "time_bits[{i}] length mismatch"); + } + for (i, (col, _w)) in rhs_terms.iter().enumerate() { + debug_assert_eq!(col.len(), 1usize << ell_n, "rhs_terms[{i}] length mismatch"); + } + debug_assert_eq!(time_bits.len(), r_addr.len(), "time_bits/r_addr length mismatch"); + + 
let mut claim = K::ZERO; + for &(t, gate) in has_lookup.entries() { + if gate == K::ZERO { + continue; + } + + let v_t = val.get(t); + let lhs_t = lhs.get(t); + let mut rhs_t = K::ZERO; + for (col, w) in rhs_terms.iter() { + rhs_t += *w * col.get(t); + } + + let hash_t = K::ONE + alpha * v_t + beta * lhs_t + gamma * rhs_t; + if hash_t == K::ZERO { + continue; + } + + let mut eq_addr = K::ONE; + for (b, col) in time_bits.iter().enumerate() { + eq_addr *= eq_bit_affine(col.get(t), r_addr[b]); + } + + claim += gate * hash_t * eq_addr; + } + + ( + Self { + degree_bound: 2 + r_addr.len(), + r_addr: r_addr.to_vec(), + time_bits, + has_lookup, + val, + lhs, + rhs_terms, + alpha, + beta, + gamma, + }, + claim, + ) + } +} + +impl RoundOracle for ShoutEventTableHashOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let v = self.val.singleton_value(); + let lhs = self.lhs.singleton_value(); + let mut rhs = K::ZERO; + for (col, w) in self.rhs_terms.iter() { + rhs += *w * col.singleton_value(); + } + let hash = gate * (K::ONE + self.alpha * v + self.beta * lhs + self.gamma * rhs); + + let mut eq_addr = K::ONE; + for (b, col) in self.time_bits.iter().enumerate() { + eq_addr *= eq_bit_affine(col.singleton_value(), self.r_addr[b]); + } + + let out = hash * eq_addr; + return vec![out; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let v0 = self.val.get(child0); + let v1 = self.val.get(child1); + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + + let mut rhs0 = 
K::ZERO; + let mut rhs1 = K::ZERO; + for (col, w) in self.rhs_terms.iter() { + rhs0 += *w * col.get(child0); + rhs1 += *w * col.get(child1); + } + + let mut eq0s: Vec = Vec::with_capacity(self.time_bits.len()); + let mut d_eqs: Vec = Vec::with_capacity(self.time_bits.len()); + for (b, col) in self.time_bits.iter().enumerate() { + let e0 = eq_bit_affine(col.get(child0), self.r_addr[b]); + let e1 = eq_bit_affine(col.get(child1), self.r_addr[b]); + eq0s.push(e0); + d_eqs.push(e1 - e0); + } + + for (i, &x) in points.iter().enumerate() { + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let v_x = interp(v0, v1, x); + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = interp(rhs0, rhs1, x); + + let mut prod = gate_x * (K::ONE + self.alpha * v_x + self.beta * lhs_x + self.gamma * rhs_x); + for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { + prod *= *e0 + *de * x; + } + ys[i] += prod; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + log2_pow2(self.has_lookup.len()) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.has_lookup.fold_round_in_place(r); + self.val.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + for (col, _w) in self.rhs_terms.iter_mut() { + col.fold_round_in_place(r); + } + for col in self.time_bits.iter_mut() { + col.fold_round_in_place(r); + } + } +} + +pub(crate) fn build_twist_inc_terms_at_r_addr(lanes: &[TwistLaneSparseCols], r_addr: &[K]) -> Vec<(usize, K)> { + let ell_addr = r_addr.len(); + let mut out: Vec<(usize, K)> = Vec::new(); + + for lane in lanes.iter() { + debug_assert_eq!(lane.wa_bits.len(), ell_addr, "wa_bits len mismatch"); + for &(t, has_w) in lane.has_write.entries() { + let inc_t = lane.inc_at_write_addr.get(t); + if has_w == K::ZERO || inc_t == K::ZERO { + continue; + } + + let mut eq_addr = K::ONE; + for (b, col) in lane.wa_bits.iter().enumerate() { + let bit = col.get(t); + eq_addr *= 
eq_bit_affine(bit, r_addr[b]); + } + + let inc_at_r_addr = has_w * inc_t * eq_addr; + if inc_at_r_addr != K::ZERO { + out.push((t, inc_at_r_addr)); + } + } + } + + out +} + +pub struct RouteAShoutTimeOracles { + pub lanes: Vec, + pub bitness: Vec>, +} + +pub struct RouteAShoutTimeLaneOracles { + pub value: Box, + pub value_claim: K, + pub adapter: Box, + pub adapter_claim: K, + pub event_table_hash: Option>, + pub event_table_hash_claim: Option, + pub gamma_group: Option, +} + +pub struct RouteAShoutGammaGroupOracles { + pub value: Box, + pub value_claim: K, + pub adapter: Box, + pub adapter_claim: K, +} + +pub struct RouteATwistTimeOracles { + pub read_check: Box, + pub write_check: Box, + pub bitness: Vec>, +} + +pub struct RouteAMemoryOracles { + pub shout: Vec, + pub shout_gamma_groups: Vec, + pub shout_event_trace_hash: Option, + pub twist: Vec, +} + +pub struct RouteAShoutEventTraceHashOracle { + pub oracle: Box, + pub claim: K, +} + +pub trait TimeBatchedClaims { + fn append_time_claims<'a>( + &'a mut self, + ell_n: usize, + claimed_sums: &mut Vec, + degree_bounds: &mut Vec, + labels: &mut Vec<&'static [u8]>, + claim_is_dynamic: &mut Vec, + claims: &mut Vec>, + ); +} + +pub(crate) struct ShoutAddrPreBatchProverData { + pub addr_pre: ShoutAddrPreProof, + pub decoded: Vec, +} + +#[derive(Clone, Debug)] +pub struct ShoutAddrPreVerifyData { + pub is_active: bool, + pub addr_claim_sum: K, + pub addr_final: K, + pub r_addr: Vec, + pub table_eval_at_r_addr: K, +} + +pub(crate) struct TwistAddrPreProverData { + pub addr_pre: BatchedAddrProof, + pub decoded: TwistDecodedColsSparse, + /// Time-lane claimed sum for the read-check oracle (output of addr-pre). + pub read_check_claim_sum: K, + /// Time-lane claimed sum for the write-check oracle (output of addr-pre). 
+ pub write_check_claim_sum: K, +} + +pub struct TwistAddrPreVerifyData { + pub r_addr: Vec, + pub read_check_claim_sum: K, + pub write_check_claim_sum: K, +} + +#[derive(Clone, Debug)] +pub struct TwistTimeLaneOpeningsLane { + pub wa_bits: Vec, + pub has_write: K, + pub inc_at_write_addr: K, +} + +#[derive(Clone, Debug)] +pub struct TwistTimeLaneOpenings { + pub lanes: Vec, +} + +#[derive(Clone, Debug)] +pub struct RouteAMemoryVerifyOutput { + pub claim_idx_end: usize, + pub twist_time_openings: Vec, +} + +#[derive(Clone, Copy)] +pub(crate) struct TraceCpuLinkOpenings { + pub(crate) shout_has_lookup: K, + pub(crate) shout_val: K, + pub(crate) shout_lhs: K, + pub(crate) shout_rhs: K, +} + +#[derive(Clone, Copy, Debug, Default)] +pub(crate) struct ShoutTraceLinkSums { + pub(crate) has_lookup: K, + pub(crate) val: K, + pub(crate) lhs: K, + pub(crate) rhs: K, + pub(crate) table_id: K, +} + +#[inline] +pub(crate) fn verify_non_event_trace_shout_linkage( + cpu: TraceCpuLinkOpenings, + sums: ShoutTraceLinkSums, + expected_table_id: Option, +) -> Result<(), PiCcsError> { + if sums.has_lookup != cpu.shout_has_lookup { + return Err(PiCcsError::ProtocolError( + "trace linkage failed: Shout has_lookup mismatch".into(), + )); + } + if sums.val != cpu.shout_val { + return Err(PiCcsError::ProtocolError( + "trace linkage failed: Shout val mismatch".into(), + )); + } + if sums.lhs != cpu.shout_lhs { + return Err(PiCcsError::ProtocolError( + "trace linkage failed: Shout lhs mismatch".into(), + )); + } + if sums.rhs != cpu.shout_rhs { + return Err(PiCcsError::ProtocolError( + "trace linkage failed: Shout rhs mismatch".into(), + )); + } + if let Some(expected_table_id) = expected_table_id { + if sums.table_id != expected_table_id { + return Err(PiCcsError::ProtocolError( + "trace linkage failed: Shout table_id mismatch".into(), + )); + } + } + Ok(()) +} + +#[inline] +pub(crate) fn eq_single_k(a: K, b: K) -> K { + a * b + (K::ONE - a) * (K::ONE - b) +} + +pub(crate) fn 
chi_cycle_children(r_cycle: &[K], bit_idx: usize, prefix_eq: K, pair_idx: usize) -> (K, K) { + let mut suffix = K::ONE; + let mut shift = bit_idx + 1; + let mut idx = pair_idx; + while shift < r_cycle.len() { + let bit = idx & 1; + let bit_k = if bit == 1 { K::ONE } else { K::ZERO }; + suffix *= eq_bit_affine(bit_k, r_cycle[shift]); + idx >>= 1; + shift += 1; + } + + let r = r_cycle[bit_idx]; + let child0 = prefix_eq * (K::ONE - r) * suffix; + let child1 = prefix_eq * r * suffix; + (child0, child1) +} + +#[inline] +pub(crate) fn wb_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5742_5F42_4F4F_4Cu64) +} + +#[inline] +pub(crate) fn w2_decode_pack_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5732_5F50_4143_4Bu64) +} + +#[inline] +pub(crate) fn w2_decode_imm_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5732_5F49_4D4D_214Du64) +} + +#[inline] +pub(crate) fn w3_bitness_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5733_5F42_4954_2144u64) +} + +#[inline] +pub(crate) fn w3_quiescence_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5733_5F51_5549_4553u64) +} + +#[inline] +pub(crate) fn w3_load_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5733_5F4C_4F41_4421u64) +} + +#[inline] +pub(crate) fn w3_store_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5733_5F53_544F_5245u64) +} + +#[inline] +pub(crate) fn control_next_pc_linear_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x4354_524C_4E50_434Cu64) +} + +#[inline] +pub(crate) fn control_next_pc_control_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x4354_524C_4E50_4343u64) +} + +#[inline] +pub(crate) fn control_branch_semantics_weight_vector(r_cycle: &[K], len: usize) -> Vec { + 
bitness_weights(r_cycle, len, 0x4354_524C_4252_534Du64) +} + +#[inline] +pub(crate) fn control_writeback_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x4354_524C_5752_4255u64) +} + +#[inline] +pub(crate) fn wp_weight_vector(r_cycle: &[K], len: usize) -> Vec { + bitness_weights(r_cycle, len, 0x5750_5F51_5549_4553u64) +} + +pub(crate) fn rv32_trace_wb_columns(layout: &Rv32TraceLayout) -> Vec { + vec![layout.active, layout.halted, layout.shout_has_lookup] +} + +pub(crate) const W2_FIELDS_RESIDUAL_COUNT: usize = 70; +pub(crate) const W2_IMM_RESIDUAL_COUNT: usize = 4; + +#[inline] +pub(crate) fn w2_bool01(v: K) -> K { + v * (v - K::ONE) +} + +#[inline] +pub(crate) fn w2_decode_selector_residuals( + active: K, + decode_opcode: K, + opcode_flags: [K; 12], + funct3_is: [K; 8], + funct3_bits: [K; 3], + op_amo: K, +) -> [K; 8] { + let opcode_one_hot = opcode_flags.into_iter().fold(K::ZERO, |acc, v| acc + v) - active; + let funct3_one_hot = funct3_is.into_iter().fold(K::ZERO, |acc, v| acc + v) - active; + let funct3_bit0_link = (funct3_is[1] + funct3_is[3] + funct3_is[5] + funct3_is[7]) - funct3_bits[0]; + let funct3_bit1_link = (funct3_is[2] + funct3_is[3] + funct3_is[6] + funct3_is[7]) - funct3_bits[1]; + let funct3_bit2_link = (funct3_is[4] + funct3_is[5] + funct3_is[6] + funct3_is[7]) - funct3_bits[2]; + let branch_f3b1_link = (funct3_is[6] + funct3_is[7]) - (funct3_bits[1] * funct3_bits[2]); + // Tier-2.1 trace mode lock: op_amo must be zero on every row. 
+ let amo_forbidden = op_amo; + let opcode_value_link = opcode_flags[0] * K::from(F::from_u64(0x37)) + + opcode_flags[1] * K::from(F::from_u64(0x17)) + + opcode_flags[2] * K::from(F::from_u64(0x6f)) + + opcode_flags[3] * K::from(F::from_u64(0x67)) + + opcode_flags[4] * K::from(F::from_u64(0x63)) + + opcode_flags[5] * K::from(F::from_u64(0x03)) + + opcode_flags[6] * K::from(F::from_u64(0x23)) + + opcode_flags[7] * K::from(F::from_u64(0x13)) + + opcode_flags[8] * K::from(F::from_u64(0x33)) + + opcode_flags[9] * K::from(F::from_u64(0x0f)) + + opcode_flags[10] * K::from(F::from_u64(0x73)) + + opcode_flags[11] * K::from(F::from_u64(0x2f)) + - decode_opcode; + + [ + opcode_one_hot, + funct3_one_hot, + funct3_bit0_link, + funct3_bit1_link, + funct3_bit2_link, + branch_f3b1_link, + amo_forbidden, + opcode_value_link, + ] +} + +#[inline] +pub(crate) fn w2_decode_bitness_residuals(opcode_flags: [K; 12], funct3_is: [K; 8]) -> [K; 20] { + [ + w2_bool01(opcode_flags[0]), + w2_bool01(opcode_flags[1]), + w2_bool01(opcode_flags[2]), + w2_bool01(opcode_flags[3]), + w2_bool01(opcode_flags[4]), + w2_bool01(opcode_flags[5]), + w2_bool01(opcode_flags[6]), + w2_bool01(opcode_flags[7]), + w2_bool01(opcode_flags[8]), + w2_bool01(opcode_flags[9]), + w2_bool01(opcode_flags[10]), + w2_bool01(opcode_flags[11]), + w2_bool01(funct3_is[0]), + w2_bool01(funct3_is[1]), + w2_bool01(funct3_is[2]), + w2_bool01(funct3_is[3]), + w2_bool01(funct3_is[4]), + w2_bool01(funct3_is[5]), + w2_bool01(funct3_is[6]), + w2_bool01(funct3_is[7]), + ] +} + +#[inline] +pub(crate) fn w2_alu_branch_lookup_residuals( + active: K, + halted: K, + shout_has_lookup: K, + shout_lhs: K, + shout_rhs: K, + shout_table_id: K, + rs1_val: K, + rs2_val: K, + rd_has_write: K, + rd_is_zero: K, + rd_val: K, + ram_has_read: K, + ram_has_write: K, + ram_addr: K, + shout_val: K, + funct3_bits: [K; 3], + funct7_bits: [K; 7], + opcode_flags: [K; 12], + op_write_flags: [K; 6], + funct3_is: [K; 8], + alu_reg_table_delta: K, + 
alu_imm_table_delta: K, + alu_imm_shift_rhs_delta: K, + rs2_decode: K, + imm_i: K, + imm_s: K, +) -> [K; 42] { + let op_lui = opcode_flags[0]; + let op_auipc = opcode_flags[1]; + let op_jal = opcode_flags[2]; + let op_jalr = opcode_flags[3]; + let op_branch = opcode_flags[4]; + let op_alu_imm = opcode_flags[7]; + let op_alu_reg = opcode_flags[8]; + let op_misc_mem = opcode_flags[9]; + let op_system = opcode_flags[10]; + + let op_lui_write = op_write_flags[0]; + let op_auipc_write = op_write_flags[1]; + let op_jal_write = op_write_flags[2]; + let op_jalr_write = op_write_flags[3]; + let op_alu_imm_write = op_write_flags[4]; + let op_alu_reg_write = op_write_flags[5]; + + let non_mem_ops = + op_lui + op_auipc + op_jal + op_jalr + op_branch + op_alu_imm + op_alu_reg + op_misc_mem + op_system; + + let alu_table_base = K::from(F::from_u64(3)) * funct3_is[0] + + K::from(F::from_u64(7)) * funct3_is[1] + + K::from(F::from_u64(5)) * funct3_is[2] + + K::from(F::from_u64(6)) * funct3_is[3] + + K::from(F::from_u64(1)) * funct3_is[4] + + K::from(F::from_u64(8)) * funct3_is[5] + + K::from(F::from_u64(2)) * funct3_is[6]; + let branch_table_expected = + K::from(F::from_u64(10)) - K::from(F::from_u64(5)) * funct3_bits[2] + (funct3_bits[1] * funct3_bits[2]); + let shift_selector = funct3_is[1] + funct3_is[5]; + + [ + op_alu_imm * (shout_has_lookup - K::ONE), + op_alu_reg * (shout_has_lookup - K::ONE), + op_branch * (shout_has_lookup - K::ONE), + (K::ONE - shout_has_lookup) * shout_table_id, + (op_alu_imm + op_alu_reg + op_branch) * (shout_lhs - rs1_val), + alu_imm_shift_rhs_delta - shift_selector * (rs2_decode - imm_i), + op_alu_imm * (shout_rhs - imm_i - alu_imm_shift_rhs_delta), + op_alu_reg * (shout_rhs - rs2_val), + op_branch * (shout_rhs - rs2_val), + op_alu_imm_write * (rd_val - shout_val), + op_alu_reg_write * (rd_val - shout_val), + op_alu_reg * (shout_table_id - alu_table_base - alu_reg_table_delta), + op_alu_imm * (shout_table_id - alu_table_base - alu_imm_table_delta), + 
op_branch * (shout_table_id - branch_table_expected), + op_alu_reg * funct7_bits[0], + alu_reg_table_delta - funct7_bits[5] * (funct3_is[0] + funct3_is[5]), + alu_imm_table_delta - funct7_bits[5] * funct3_is[5], + op_lui * rd_has_write - op_lui_write, + op_auipc * rd_has_write - op_auipc_write, + op_jal * rd_has_write - op_jal_write, + op_jalr * rd_has_write - op_jalr_write, + op_alu_imm * rd_has_write - op_alu_imm_write, + op_alu_reg * rd_has_write - op_alu_reg_write, + op_lui * (rd_has_write + rd_is_zero - K::ONE), + op_auipc * (rd_has_write + rd_is_zero - K::ONE), + op_jal * (rd_has_write + rd_is_zero - K::ONE), + op_jalr * (rd_has_write + rd_is_zero - K::ONE), + opcode_flags[5] * (rd_has_write + rd_is_zero - K::ONE), + op_alu_imm * (rd_has_write + rd_is_zero - K::ONE), + op_alu_reg * (rd_has_write + rd_is_zero - K::ONE), + op_branch * rd_has_write, + opcode_flags[6] * rd_has_write, + op_misc_mem * rd_has_write, + op_system * rd_has_write, + active * (halted - op_system), + opcode_flags[5] * (ram_has_read - K::ONE), + opcode_flags[6] * (ram_has_write - K::ONE), + non_mem_ops * ram_has_read, + non_mem_ops * ram_has_write, + non_mem_ops * ram_addr, + opcode_flags[5] * (ram_addr - rs1_val - imm_i), + opcode_flags[6] * (ram_addr - rs1_val - imm_s), + ] +} + +#[inline] +pub(crate) fn w2_decode_immediate_residuals( + imm_i: K, + imm_s: K, + imm_b: K, + imm_j: K, + rd_bits: [K; 5], + funct3_bits: [K; 3], + rs1_bits: [K; 5], + rs2_bits: [K; 5], + funct7_bits: [K; 7], +) -> [K; 4] { + let signext_imm12 = K::from(F::from_u64((1u64 << 32) - (1u64 << 11))); + let signext_imm13 = K::from(F::from_u64((1u64 << 32) - (1u64 << 12))); + let signext_imm21 = K::from(F::from_u64((1u64 << 32) - (1u64 << 20))); + + let imm_i_res = imm_i + - rs2_bits[0] + - K::from(F::from_u64(2)) * rs2_bits[1] + - K::from(F::from_u64(4)) * rs2_bits[2] + - K::from(F::from_u64(8)) * rs2_bits[3] + - K::from(F::from_u64(16)) * rs2_bits[4] + - K::from(F::from_u64(32)) * funct7_bits[0] + - 
K::from(F::from_u64(64)) * funct7_bits[1] + - K::from(F::from_u64(128)) * funct7_bits[2] + - K::from(F::from_u64(256)) * funct7_bits[3] + - K::from(F::from_u64(512)) * funct7_bits[4] + - K::from(F::from_u64(1024)) * funct7_bits[5] + - signext_imm12 * funct7_bits[6]; + + let imm_s_res = imm_s + - rd_bits[0] + - K::from(F::from_u64(2)) * rd_bits[1] + - K::from(F::from_u64(4)) * rd_bits[2] + - K::from(F::from_u64(8)) * rd_bits[3] + - K::from(F::from_u64(16)) * rd_bits[4] + - K::from(F::from_u64(32)) * funct7_bits[0] + - K::from(F::from_u64(64)) * funct7_bits[1] + - K::from(F::from_u64(128)) * funct7_bits[2] + - K::from(F::from_u64(256)) * funct7_bits[3] + - K::from(F::from_u64(512)) * funct7_bits[4] + - K::from(F::from_u64(1024)) * funct7_bits[5] + - signext_imm12 * funct7_bits[6]; + + let imm_b_res = imm_b + - K::from(F::from_u64(2)) * rd_bits[1] + - K::from(F::from_u64(4)) * rd_bits[2] + - K::from(F::from_u64(8)) * rd_bits[3] + - K::from(F::from_u64(16)) * rd_bits[4] + - K::from(F::from_u64(32)) * funct7_bits[0] + - K::from(F::from_u64(64)) * funct7_bits[1] + - K::from(F::from_u64(128)) * funct7_bits[2] + - K::from(F::from_u64(256)) * funct7_bits[3] + - K::from(F::from_u64(512)) * funct7_bits[4] + - K::from(F::from_u64(1024)) * funct7_bits[5] + - K::from(F::from_u64(2048)) * rd_bits[0] + - signext_imm13 * funct7_bits[6]; + + let imm_j_res = imm_j + - K::from(F::from_u64(2)) * rs2_bits[1] + - K::from(F::from_u64(4)) * rs2_bits[2] + - K::from(F::from_u64(8)) * rs2_bits[3] + - K::from(F::from_u64(16)) * rs2_bits[4] + - K::from(F::from_u64(32)) * funct7_bits[0] + - K::from(F::from_u64(64)) * funct7_bits[1] + - K::from(F::from_u64(128)) * funct7_bits[2] + - K::from(F::from_u64(256)) * funct7_bits[3] + - K::from(F::from_u64(512)) * funct7_bits[4] + - K::from(F::from_u64(1024)) * funct7_bits[5] + - K::from(F::from_u64(2048)) * rs2_bits[0] + - K::from(F::from_u64(4096)) * funct3_bits[0] + - K::from(F::from_u64(8192)) * funct3_bits[1] + - K::from(F::from_u64(16384)) * 
funct3_bits[2] + - K::from(F::from_u64(32768)) * rs1_bits[0] + - K::from(F::from_u64(65536)) * rs1_bits[1] + - K::from(F::from_u64(131072)) * rs1_bits[2] + - K::from(F::from_u64(262144)) * rs1_bits[3] + - K::from(F::from_u64(524288)) * rs1_bits[4] + - signext_imm21 * funct7_bits[6]; + + [imm_i_res, imm_s_res, imm_b_res, imm_j_res] +} + +#[inline] +pub(crate) fn w3_load_semantics_residuals( + rd_val: K, + ram_rv: K, + rd_has_write: K, + ram_has_read: K, + load_flags: [K; 5], + ram_rv_q16: K, + ram_rv_low_bits: [K; 16], +) -> [K; 16] { + let pow2 = |k: usize| K::from(F::from_u64(1u64 << k)); + let two16 = K::from(F::from_u64(1u64 << 16)); + let lb_sign_coeff = K::from(F::from_u64((1u64 << 32) - (1u64 << 7))); + let lh_sign_coeff = K::from(F::from_u64((1u64 << 32) - (1u64 << 15))); + + let mut ram_rv_low8 = K::ZERO; + for (k, b) in ram_rv_low_bits.iter().copied().enumerate().take(8) { + ram_rv_low8 += pow2(k) * b; + } + let mut ram_rv_low16 = K::ZERO; + for (k, b) in ram_rv_low_bits.iter().copied().enumerate() { + ram_rv_low16 += pow2(k) * b; + } + + let lb_val = { + let mut acc = K::ZERO; + for (k, b) in ram_rv_low_bits.iter().copied().enumerate().take(8) { + acc += if k == 7 { lb_sign_coeff } else { pow2(k) } * b; + } + acc + }; + let lh_val = { + let mut acc = K::ZERO; + for (k, b) in ram_rv_low_bits.iter().copied().enumerate() { + if k >= 16 { + break; + } + acc += if k == 15 { lh_sign_coeff } else { pow2(k) } * b; + } + acc + }; + + [ + load_flags[4] * (rd_val - ram_rv), + load_flags[0] * (rd_val - lb_val), + load_flags[1] * (rd_val - ram_rv_low8), + load_flags[2] * (rd_val - lh_val), + load_flags[3] * (rd_val - ram_rv_low16), + load_flags[0] * (rd_has_write - K::ONE), + load_flags[1] * (rd_has_write - K::ONE), + load_flags[2] * (rd_has_write - K::ONE), + load_flags[3] * (rd_has_write - K::ONE), + load_flags[4] * (rd_has_write - K::ONE), + load_flags[0] * (ram_has_read - K::ONE), + load_flags[1] * (ram_has_read - K::ONE), + load_flags[2] * (ram_has_read - 
K::ONE), + load_flags[3] * (ram_has_read - K::ONE), + load_flags[4] * (ram_has_read - K::ONE), + ram_has_read * (ram_rv - two16 * ram_rv_q16 - ram_rv_low16), + ] +} + +#[inline] +pub(crate) fn w3_store_semantics_residuals( + ram_wv: K, + ram_rv: K, + rs2_val: K, + rd_has_write: K, + ram_has_read: K, + ram_has_write: K, + store_flags: [K; 3], + rs2_q16: K, + ram_rv_low_bits: [K; 16], + rs2_low_bits: [K; 16], +) -> [K; 12] { + let pow2 = |k: usize| K::from(F::from_u64(1u64 << k)); + let two16 = K::from(F::from_u64(1u64 << 16)); + let mut rs2_low16 = K::ZERO; + let mut sb_patch = K::ZERO; + let mut sh_patch = K::ZERO; + for k in 0..16 { + let coeff = pow2(k); + rs2_low16 += coeff * rs2_low_bits[k]; + if k < 8 { + sb_patch += coeff * (ram_rv_low_bits[k] - rs2_low_bits[k]); + } + sh_patch += coeff * (ram_rv_low_bits[k] - rs2_low_bits[k]); + } + [ + store_flags[2] * (ram_wv - rs2_val), + store_flags[0] * (ram_wv - ram_rv + sb_patch), + store_flags[1] * (ram_wv - ram_rv + sh_patch), + store_flags[0] * rd_has_write, + store_flags[1] * rd_has_write, + store_flags[2] * rd_has_write, + store_flags[0] * (ram_has_read - K::ONE), + store_flags[1] * (ram_has_read - K::ONE), + store_flags[0] * (ram_has_write - K::ONE), + store_flags[1] * (ram_has_write - K::ONE), + store_flags[2] * (ram_has_write - K::ONE), + rs2_val - two16 * rs2_q16 - rs2_low16, + ] +} + +#[inline] +pub(crate) fn control_branch_taken_from_bits(shout_val: K, funct3_bit0: K) -> K { + shout_val + funct3_bit0 - K::from(F::from_u64(2)) * funct3_bit0 * shout_val +} + +#[inline] +pub(crate) fn control_imm_u_from_bits( + funct3_bits: [K; 3], + rs1_bits: [K; 5], + rs2_bits: [K; 5], + funct7_bits: [K; 7], +) -> K { + let pow2 = |k: u64| K::from(F::from_u64(1u64 << k)); + let mut out = K::ZERO; + out += pow2(12) * funct3_bits[0]; + out += pow2(13) * funct3_bits[1]; + out += pow2(14) * funct3_bits[2]; + out += pow2(15) * rs1_bits[0]; + out += pow2(16) * rs1_bits[1]; + out += pow2(17) * rs1_bits[2]; + out += pow2(18) * 
rs1_bits[3]; + out += pow2(19) * rs1_bits[4]; + out += pow2(20) * rs2_bits[0]; + out += pow2(21) * rs2_bits[1]; + out += pow2(22) * rs2_bits[2]; + out += pow2(23) * rs2_bits[3]; + out += pow2(24) * rs2_bits[4]; + out += pow2(25) * funct7_bits[0]; + out += pow2(26) * funct7_bits[1]; + out += pow2(27) * funct7_bits[2]; + out += pow2(28) * funct7_bits[3]; + out += pow2(29) * funct7_bits[4]; + out += pow2(30) * funct7_bits[5]; + out += pow2(31) * funct7_bits[6]; + out +} + +#[inline] +pub(crate) fn control_next_pc_linear_residual( + pc_before: K, + pc_after: K, + op_lui: K, + op_auipc: K, + op_load: K, + op_store: K, + op_alu_imm: K, + op_alu_reg: K, + op_misc_mem: K, + op_system: K, + op_amo: K, +) -> K { + let op_linear = op_lui + op_auipc + op_load + op_store + op_alu_imm + op_alu_reg + op_misc_mem + op_system + op_amo; + op_linear * (pc_after - pc_before - K::from(F::from_u64(4))) +} + +#[inline] +pub(crate) fn control_next_pc_control_residuals( + active: K, + pc_before: K, + pc_after: K, + rs1_val: K, + jalr_drop_bit: K, + imm_i: K, + imm_b: K, + imm_j: K, + op_jal: K, + op_jalr: K, + op_branch: K, + shout_val: K, + funct3_bit0: K, +) -> [K; 5] { + let four = K::from(F::from_u64(4)); + let taken = control_branch_taken_from_bits(shout_val, funct3_bit0); + [ + op_jal * (pc_after - pc_before - imm_j), + op_jalr * (pc_after - rs1_val - imm_i + jalr_drop_bit), + op_branch * (pc_after - pc_before - four - taken * (imm_b - four)), + op_jalr * jalr_drop_bit * (jalr_drop_bit - K::ONE), + (active - op_jalr) * jalr_drop_bit, + ] +} + +#[inline] +pub(crate) fn control_branch_semantics_residuals( + op_branch: K, + shout_val: K, + _funct3_bit0: K, + funct3_bit1: K, + funct3_bit2: K, + funct3_is6: K, + funct3_is7: K, +) -> [K; 2] { + [ + op_branch * ((funct3_is6 + funct3_is7) - funct3_bit1 * funct3_bit2), + op_branch * shout_val * (shout_val - K::ONE), + ] +} + +#[inline] +pub(crate) fn control_writeback_residuals( + rd_val: K, + pc_before: K, + imm_u: K, + op_lui_write: K, + 
op_auipc_write: K, + op_jal_write: K, + op_jalr_write: K, +) -> [K; 4] { + let four = K::from(F::from_u64(4)); + [ + op_lui_write * (rd_val - imm_u), + op_auipc_write * (rd_val - pc_before - imm_u), + op_jal_write * (rd_val - pc_before - four), + op_jalr_write * (rd_val - pc_before - four), + ] +} + +pub(crate) fn rv32_trace_wp_columns(layout: &Rv32TraceLayout) -> Vec { + vec![ + layout.instr_word, + layout.rs1_addr, + layout.rs1_val, + layout.rs2_addr, + layout.rs2_val, + layout.rd_addr, + layout.rd_val, + layout.ram_addr, + layout.ram_rv, + layout.ram_wv, + layout.shout_has_lookup, + layout.shout_val, + layout.shout_lhs, + layout.shout_rhs, + layout.jalr_drop_bit, + ] +} + +pub(crate) fn rv32_trace_wp_opening_columns(layout: &Rv32TraceLayout) -> Vec { + let mut out = Vec::with_capacity(1 + layout.cols); + out.push(layout.active); + out.extend(rv32_trace_wp_columns(layout)); + out +} + +pub(crate) fn rv32_trace_control_extra_opening_columns(layout: &Rv32TraceLayout) -> Vec { + vec![layout.pc_before, layout.pc_after] +} + +pub(crate) fn infer_rv32_trace_t_len_for_wb_wp( + step: &StepWitnessBundle, + trace: &Rv32TraceLayout, +) -> Result { + if let Some((inst, _)) = step.mem_instances.first() { + return Ok(inst.steps); + } + if let Some((inst, _)) = step.lut_instances.first() { + return Ok(inst.steps); + } + + let m_in = step.mcs.0.m_in; + let m = step.mcs.1.Z.cols(); + let w = m + .checked_sub(m_in) + .ok_or_else(|| PiCcsError::InvalidInput("trace width underflow while inferring t_len".into()))?; + if trace.cols == 0 || w % trace.cols != 0 { + return Err(PiCcsError::InvalidInput( + "cannot infer RV32 trace t_len for WB/WP (missing mem/lut instances and non-divisible witness width)" + .into(), + )); + } + let t_len = w / trace.cols; + if t_len == 0 { + return Err(PiCcsError::InvalidInput( + "RV32 trace t_len must be >= 1 for WB/WP".into(), + )); + } + Ok(t_len) +} + +pub(crate) fn decode_trace_col_values_batch( + params: &NeoParams, + step: &StepWitnessBundle, + 
t_len: usize, + col_ids: &[usize], +) -> Result>, PiCcsError> { + let m_in = step.mcs.0.m_in; + let m = step.mcs.1.Z.cols(); + let d = neo_math::D; + let z = &step.mcs.1.Z; + if z.rows() != d { + return Err(PiCcsError::InvalidInput(format!( + "WB/WP: CPU witness Z.rows()={} != D={d}", + z.rows() + ))); + } + + let trace_base = m_in; + let b_k = K::from(F::from_u64(params.b as u64)); + let mut pow_b = Vec::with_capacity(d); + let mut cur = K::ONE; + for _ in 0..d { + pow_b.push(cur); + cur *= b_k; + } + + let unique_col_ids: BTreeSet = col_ids.iter().copied().collect(); + let mut decoded = BTreeMap::>::new(); + for col_id in unique_col_ids { + let col_start = trace_base + .checked_add( + col_id + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("WB/WP: col_id * t_len overflow".into()))?, + ) + .ok_or_else(|| PiCcsError::InvalidInput("WB/WP: trace column start overflow".into()))?; + + let mut out = Vec::with_capacity(t_len); + for j in 0..t_len { + let idx = col_start + .checked_add(j) + .ok_or_else(|| PiCcsError::InvalidInput("WB/WP: trace z idx overflow".into()))?; + if idx >= m { + return Err(PiCcsError::InvalidInput(format!( + "WB/WP: trace z idx out of range (idx={idx}, m={m})" + ))); + } + let mut acc = K::ZERO; + for rho in 0..d { + acc += pow_b[rho] * K::from(z[(rho, idx)]); + } + out.push(acc); + } + decoded.insert(col_id, out); + } + + Ok(decoded) +} + +pub(crate) fn decode_lookup_backed_col_values_batch( + params: &NeoParams, + m_in: usize, + t_len: usize, + z: &neo_ccs::matrix::Mat, + max_cols: usize, + col_ids: &[usize], +) -> Result>, PiCcsError> { + let m = z.cols(); + let d = neo_math::D; + if z.rows() != d { + return Err(PiCcsError::InvalidInput(format!( + "W2: decode lookup-backed Z.rows()={} != D={d}", + z.rows() + ))); + } + + let b_k = K::from(F::from_u64(params.b as u64)); + let mut pow_b = Vec::with_capacity(d); + let mut cur = K::ONE; + for _ in 0..d { + pow_b.push(cur); + cur *= b_k; + } + + let unique_col_ids: BTreeSet = 
col_ids.iter().copied().collect(); + let mut decoded = BTreeMap::>::new(); + for col_id in unique_col_ids { + if col_id >= max_cols { + return Err(PiCcsError::InvalidInput(format!( + "W2: decode lookup-backed column out of range (col_id={col_id}, cols={max_cols})" + ))); + } + let col_start = m_in + .checked_add( + col_id + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("W2: col_id * t_len overflow".into()))?, + ) + .ok_or_else(|| PiCcsError::InvalidInput("W2: trace column start overflow".into()))?; + let mut out = Vec::with_capacity(t_len); + for j in 0..t_len { + let idx = col_start + .checked_add(j) + .ok_or_else(|| PiCcsError::InvalidInput("W2: trace z idx overflow".into()))?; + if idx >= m { + return Err(PiCcsError::InvalidInput(format!( + "W2: decode lookup-backed z idx out of range (idx={idx}, m={m})" + ))); + } + let mut acc = K::ZERO; + for rho in 0..d { + acc += pow_b[rho] * K::from(z[(rho, idx)]); + } + out.push(acc); + } + decoded.insert(col_id, out); + } + Ok(decoded) +} diff --git a/crates/neo-fold/src/memory_sidecar/route_a_time.rs b/crates/neo-fold/src/memory_sidecar/route_a_time.rs index ee39a129..818b445a 100644 --- a/crates/neo-fold/src/memory_sidecar/route_a_time.rs +++ b/crates/neo-fold/src/memory_sidecar/route_a_time.rs @@ -36,6 +36,19 @@ pub fn prove_route_a_batched_time( step: &StepWitnessBundle, twist_read_claims: Vec, twist_write_claims: Vec, + wb_time_claim: Option, + wp_time_claim: Option, + decode_decode_fields_claim: Option, + decode_decode_immediates_claim: Option, + width_bitness_claim: Option, + width_quiescence_claim: Option, + width_selector_linkage_claim: Option, + width_load_semantics_claim: Option, + width_store_semantics_claim: Option, + control_next_pc_linear_claim: Option, + control_next_pc_control_claim: Option, + control_branch_semantics_claim: Option, + control_control_writeback_claim: Option, ob_inc_total: Option, ) -> Result { let mut claimed_sums: Vec = Vec::new(); @@ -57,7 +70,8 @@ pub fn 
prove_route_a_batched_time( label: b"ccs/time", }); - let mut shout_protocol = ShoutRouteAProtocol::new(&mut mem_oracles.shout, ell_n); + let mut shout_protocol = + ShoutRouteAProtocol::new(&mut mem_oracles.shout, &mut mem_oracles.shout_gamma_groups, ell_n); shout_protocol.append_time_claims( ell_n, &mut claimed_sums, @@ -67,6 +81,24 @@ pub fn prove_route_a_batched_time( &mut claims, ); + // Optional: event-table Shout linkage trace hash claim. + let shout_event_trace_hash_claim = mem_oracles.shout_event_trace_hash.as_ref().map(|o| o.claim); + let mut shout_event_trace_hash_prefix = mem_oracles + .shout_event_trace_hash + .as_mut() + .map(|o| RoundOraclePrefix::new(o.oracle.as_mut(), ell_n)); + if let (Some(claim), Some(prefix)) = (shout_event_trace_hash_claim, shout_event_trace_hash_prefix.as_mut()) { + claimed_sums.push(claim); + degree_bounds.push(prefix.degree_bound()); + labels.push(b"shout/event_trace_hash"); + claim_is_dynamic.push(true); + claims.push(BatchedClaim { + oracle: prefix, + claimed_sum: claim, + label: b"shout/event_trace_hash", + }); + } + let mut twist_protocol = TwistRouteAProtocol::new(&mut mem_oracles.twist, ell_n, twist_read_claims, twist_write_claims); twist_protocol.append_time_claims( @@ -78,6 +110,297 @@ pub fn prove_route_a_batched_time( &mut claims, ); + let wb_time_degree_bound = wb_time_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut wb_time_label: Option<&'static [u8]> = None; + let mut wb_time_oracle: Option> = wb_time_claim.map(|extra| { + wb_time_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = wb_time_oracle.as_deref_mut() { + // WB is a zero-identity stage: claimed sum is verifier-known and fixed to zero. 
+ let claimed_sum = K::ZERO; + let label = wb_time_label.expect("missing wb_time label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let wp_time_degree_bound = wp_time_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut wp_time_label: Option<&'static [u8]> = None; + let mut wp_time_oracle: Option> = wp_time_claim.map(|extra| { + wp_time_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = wp_time_oracle.as_deref_mut() { + // WP is a zero-identity stage: claimed sum is verifier-known and fixed to zero. + let claimed_sum = K::ZERO; + let label = wp_time_label.expect("missing wp_time label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let decode_decode_fields_degree_bound = decode_decode_fields_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut decode_decode_fields_label: Option<&'static [u8]> = None; + let mut decode_decode_fields_oracle: Option> = decode_decode_fields_claim.map(|extra| { + decode_decode_fields_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = decode_decode_fields_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = decode_decode_fields_label.expect("missing decode_fields label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let decode_decode_immediates_degree_bound = decode_decode_immediates_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut decode_decode_immediates_label: Option<&'static [u8]> = None; + let mut 
decode_decode_immediates_oracle: Option> = + decode_decode_immediates_claim.map(|extra| { + decode_decode_immediates_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = decode_decode_immediates_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = decode_decode_immediates_label.expect("missing decode_immediates label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let width_bitness_degree_bound = width_bitness_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut width_bitness_label: Option<&'static [u8]> = None; + let mut width_bitness_oracle: Option> = width_bitness_claim.map(|extra| { + width_bitness_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = width_bitness_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = width_bitness_label.expect("missing width_bitness label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let width_quiescence_degree_bound = width_quiescence_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut width_quiescence_label: Option<&'static [u8]> = None; + let mut width_quiescence_oracle: Option> = width_quiescence_claim.map(|extra| { + width_quiescence_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = width_quiescence_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = width_quiescence_label.expect("missing width_quiescence label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let 
width_selector_linkage_degree_bound = width_selector_linkage_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut width_selector_linkage_label: Option<&'static [u8]> = None; + let mut width_selector_linkage_oracle: Option> = width_selector_linkage_claim.map(|extra| { + width_selector_linkage_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = width_selector_linkage_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = width_selector_linkage_label.expect("missing width_selector_linkage label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let width_load_semantics_degree_bound = width_load_semantics_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut width_load_semantics_label: Option<&'static [u8]> = None; + let mut width_load_semantics_oracle: Option> = width_load_semantics_claim.map(|extra| { + width_load_semantics_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = width_load_semantics_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = width_load_semantics_label.expect("missing width_load_semantics label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let width_store_semantics_degree_bound = width_store_semantics_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut width_store_semantics_label: Option<&'static [u8]> = None; + let mut width_store_semantics_oracle: Option> = width_store_semantics_claim.map(|extra| { + width_store_semantics_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = width_store_semantics_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = 
width_store_semantics_label.expect("missing width_store_semantics label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let control_next_pc_linear_degree_bound = control_next_pc_linear_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut control_next_pc_linear_label: Option<&'static [u8]> = None; + let mut control_next_pc_linear_oracle: Option> = control_next_pc_linear_claim.map(|extra| { + control_next_pc_linear_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = control_next_pc_linear_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = control_next_pc_linear_label.expect("missing control_next_pc_linear label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let control_next_pc_control_degree_bound = control_next_pc_control_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut control_next_pc_control_label: Option<&'static [u8]> = None; + let mut control_next_pc_control_oracle: Option> = control_next_pc_control_claim.map(|extra| { + control_next_pc_control_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = control_next_pc_control_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = control_next_pc_control_label.expect("missing control_next_pc_control label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let control_branch_semantics_degree_bound = control_branch_semantics_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut control_branch_semantics_label: 
Option<&'static [u8]> = None; + let mut control_branch_semantics_oracle: Option> = + control_branch_semantics_claim.map(|extra| { + control_branch_semantics_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = control_branch_semantics_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = control_branch_semantics_label.expect("missing control_branch_semantics label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + + let control_control_writeback_degree_bound = control_control_writeback_claim + .as_ref() + .map(|extra| extra.oracle.degree_bound()); + let mut control_control_writeback_label: Option<&'static [u8]> = None; + let mut control_control_writeback_oracle: Option> = + control_control_writeback_claim.map(|extra| { + control_control_writeback_label = Some(extra.label); + extra.oracle + }); + if let Some(oracle) = control_control_writeback_oracle.as_deref_mut() { + let claimed_sum = K::ZERO; + let label = control_control_writeback_label.expect("missing control_writeback label"); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(false); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } + let ob_inc_total_degree_bound = ob_inc_total .as_ref() .map(|extra| extra.oracle.degree_bound()); @@ -106,6 +429,18 @@ pub fn prove_route_a_batched_time( step.lut_instances.iter().map(|(inst, _)| inst), step.mem_instances.iter().map(|(inst, _)| inst), ccs_time_degree_bound, + wb_time_degree_bound.is_some(), + wp_time_degree_bound.is_some(), + decode_decode_fields_degree_bound.is_some() || decode_decode_immediates_degree_bound.is_some(), + width_bitness_degree_bound.is_some() + || width_quiescence_degree_bound.is_some() + || width_selector_linkage_degree_bound.is_some() + || 
width_load_semantics_degree_bound.is_some() + || width_store_semantics_degree_bound.is_some(), + control_next_pc_linear_degree_bound.is_some() + || control_next_pc_control_degree_bound.is_some() + || control_branch_semantics_degree_bound.is_some() + || control_control_writeback_degree_bound.is_some(), ob_inc_total_degree_bound, ); let expected_degree_bounds: Vec = metas.iter().map(|m| m.degree_bound).collect(); @@ -164,9 +499,23 @@ pub fn verify_route_a_batched_time( claimed_initial_sum: K, step: &StepInstanceBundle, proof: &BatchedTimeProof, + wb_enabled: bool, + wp_enabled: bool, + decode_stage_enabled: bool, + width_stage_enabled: bool, + control_stage_enabled: bool, ob_inc_total_degree_bound: Option, ) -> Result { - let metas = RouteATimeClaimPlan::time_claim_metas_for_step(step, ccs_time_degree_bound, ob_inc_total_degree_bound); + let metas = RouteATimeClaimPlan::time_claim_metas_for_step( + step, + ccs_time_degree_bound, + wb_enabled, + wp_enabled, + decode_stage_enabled, + width_stage_enabled, + control_stage_enabled, + ob_inc_total_degree_bound, + ); let expected_degree_bounds: Vec = metas.iter().map(|m| m.degree_bound).collect(); let expected_labels: Vec<&'static [u8]> = metas.iter().map(|m| m.label).collect(); let claim_is_dynamic: Vec = metas.iter().map(|m| m.is_dynamic).collect(); diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs deleted file mode 100644 index 25c14283..00000000 --- a/crates/neo-fold/src/riscv_shard.rs +++ /dev/null @@ -1,741 +0,0 @@ -//! Convenience wrappers for verifying RISC-V shard proofs safely. -//! -//! These helpers are intentionally small: they standardize the step-linking configuration -//! for RV32 B1 chunked execution so callers don't accidentally verify a "bag of chunks". 
- -#![allow(non_snake_case)] - -use std::collections::HashMap; -use std::collections::HashSet; -use std::time::Duration; - -use crate::output_binding::{simple_output_config, OutputBindingConfig}; -use crate::pi_ccs::FoldingMode; -use crate::session::FoldingSession; -use crate::shard::{ - fold_shard_verify_with_output_binding_and_step_linking, fold_shard_verify_with_step_linking, CommitMixers, - ShardFoldOutputs, ShardProof, StepLinkingConfig, -}; -use crate::PiCcsError; -use neo_ajtai::{AjtaiSModule, Commitment as Cmt}; -use neo_ccs::{CcsStructure, Mat, MeInstance}; -use neo_math::{F, K}; -use neo_memory::mem_init_from_initial_mem; -use neo_memory::output_check::ProgramIO; -use neo_memory::plain::LutTable; -use neo_memory::plain::PlainMemLayout; -use neo_memory::riscv::ccs::{ - build_rv32_b1_step_ccs, rv32_b1_chunk_to_witness, rv32_b1_shared_cpu_bus_config, rv32_b1_step_linking_pairs, - Rv32B1Layout, -}; -use neo_memory::riscv::lookups::{decode_program, RiscvCpu, RiscvInstruction, RiscvOpcode, RiscvShoutTables, PROG_ID}; -use neo_memory::riscv::shard::{extract_boundary_state, Rv32BoundaryState}; -use neo_memory::witness::LutTableSpec; -use neo_memory::witness::{StepInstanceBundle, StepWitnessBundle}; -use neo_memory::R1csCpu; -use neo_params::NeoParams; -use neo_transcript::Poseidon2Transcript; -use neo_vm_trace::Twist as _; -use p3_field::PrimeCharacteristicRing; - -#[cfg(target_arch = "wasm32")] -use js_sys::Date; -#[cfg(not(target_arch = "wasm32"))] -use std::time::Instant; - -#[cfg(target_arch = "wasm32")] -type TimePoint = f64; -#[cfg(not(target_arch = "wasm32"))] -type TimePoint = Instant; - -#[inline] -fn time_now() -> TimePoint { - #[cfg(target_arch = "wasm32")] - { - Date::now() - } - #[cfg(not(target_arch = "wasm32"))] - { - Instant::now() - } -} - -#[inline] -fn elapsed_duration(start: TimePoint) -> Duration { - #[cfg(target_arch = "wasm32")] - { - let elapsed_ms = Date::now() - start; - Duration::from_secs_f64(elapsed_ms / 1_000.0) - } - 
#[cfg(not(target_arch = "wasm32"))] - { - start.elapsed() - } -} - -pub fn rv32_b1_step_linking_config(layout: &Rv32B1Layout) -> StepLinkingConfig { - StepLinkingConfig::new(rv32_b1_step_linking_pairs(layout)) -} - -/// Enforce that the *public statement* initial memory matches chunk 0's `MemInstance.init`. -/// -/// This lets later chunk `init` snapshots remain proof-internal rollover data (Twist needs them), -/// while keeping the user-facing statement independent of `chunk_size`. -pub fn rv32_b1_enforce_chunk0_mem_init_matches_statement( - mem_layouts: &HashMap, - statement_initial_mem: &HashMap<(u32, u64), F>, - steps: &[StepInstanceBundle], -) -> Result<(), PiCcsError> { - let chunk0 = steps - .first() - .ok_or_else(|| PiCcsError::InvalidInput("no steps provided".into()))?; - - let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); - mem_ids.sort_unstable(); - - if chunk0.mem_insts.len() != mem_ids.len() { - return Err(PiCcsError::InvalidInput(format!( - "mem instance count mismatch: chunk0 has {}, but mem_layouts has {}", - chunk0.mem_insts.len(), - mem_ids.len() - ))); - } - - for (idx, mem_id) in mem_ids.into_iter().enumerate() { - let layout = mem_layouts - .get(&mem_id) - .ok_or_else(|| PiCcsError::InvalidInput(format!("missing PlainMemLayout for mem_id={mem_id}")))?; - let expected = mem_init_from_initial_mem(mem_id, layout.k, statement_initial_mem)?; - let got = &chunk0.mem_insts[idx].init; - if *got != expected { - return Err(PiCcsError::InvalidInput(format!( - "chunk0 MemInstance.init mismatch for mem_id={mem_id}" - ))); - } - } - - Ok(()) -} - -pub fn fold_shard_verify_rv32_b1( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - layout: &Rv32B1Layout, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let 
step_linking = rv32_b1_step_linking_config(layout); - fold_shard_verify_with_step_linking(mode, tr, params, s_me, steps, acc_init, proof, mixers, &step_linking) -} - -pub fn fold_shard_verify_rv32_b1_with_statement_mem_init( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - mem_layouts: &HashMap, - statement_initial_mem: &HashMap<(u32, u64), F>, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - layout: &Rv32B1Layout, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - rv32_b1_enforce_chunk0_mem_init_matches_statement(mem_layouts, statement_initial_mem, steps)?; - fold_shard_verify_rv32_b1(mode, tr, params, s_me, steps, acc_init, proof, mixers, layout) -} - -pub fn fold_shard_verify_rv32_b1_with_output_binding( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - layout: &Rv32B1Layout, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let step_linking = rv32_b1_step_linking_config(layout); - fold_shard_verify_with_output_binding_and_step_linking( - mode, - tr, - params, - s_me, - steps, - acc_init, - proof, - mixers, - ob_cfg, - &step_linking, - ) -} - -fn pow2_ceil_k(min_k: usize) -> (usize, usize) { - // RV32 B1 alignment constraints require bit-addressed memories with d>=2. 
- let k = min_k.next_power_of_two().max(4); - let d = k.trailing_zeros() as usize; - (k, d) -} - -fn infer_required_shout_opcodes(program: &[RiscvInstruction]) -> HashSet { - let mut ops = HashSet::new(); - - // The ADD table is required because the step circuit uses it for address/PC wiring in multiple - // instructions (LW/SW/AUIPC/JALR), even if the program has no explicit ADD/ADDI. - ops.insert(RiscvOpcode::Add); - - for instr in program { - match instr { - RiscvInstruction::RAlu { op, .. } => { - match op { - // RV32 B1 proves RV32M MUL* in-circuit (no Shout table required). - RiscvOpcode::Mul | RiscvOpcode::Mulh | RiscvOpcode::Mulhu | RiscvOpcode::Mulhsu => {} - // RV32 B1 proves RV32M DIV*/REM* in-circuit, but it requires a SLTU lookup to prove - // the remainder bound when divisor != 0 (unsigned and signed). - RiscvOpcode::Div | RiscvOpcode::Divu | RiscvOpcode::Rem | RiscvOpcode::Remu => { - ops.insert(RiscvOpcode::Sltu); - } - _ => { - ops.insert(*op); - } - } - } - RiscvInstruction::IAlu { op, .. } => { - ops.insert(*op); - } - RiscvInstruction::Branch { cond, .. } => { - ops.insert(cond.to_shout_opcode()); - } - RiscvInstruction::Load { .. } => { - ops.insert(RiscvOpcode::Add); - } - RiscvInstruction::Store { .. } => { - ops.insert(RiscvOpcode::Add); - } - RiscvInstruction::Jalr { .. } => { - ops.insert(RiscvOpcode::Add); - } - RiscvInstruction::Auipc { .. } => { - ops.insert(RiscvOpcode::Add); - } - RiscvInstruction::Amo { op, .. 
} => match op { - neo_memory::riscv::lookups::RiscvMemOp::AmoaddW | neo_memory::riscv::lookups::RiscvMemOp::AmoaddD => { - ops.insert(RiscvOpcode::Add); - } - neo_memory::riscv::lookups::RiscvMemOp::AmoxorW | neo_memory::riscv::lookups::RiscvMemOp::AmoxorD => { - ops.insert(RiscvOpcode::Xor); - } - neo_memory::riscv::lookups::RiscvMemOp::AmoandW | neo_memory::riscv::lookups::RiscvMemOp::AmoandD => { - ops.insert(RiscvOpcode::And); - } - neo_memory::riscv::lookups::RiscvMemOp::AmoorW | neo_memory::riscv::lookups::RiscvMemOp::AmoorD => { - ops.insert(RiscvOpcode::Or); - } - _ => {} - }, - _ => {} - } - } - - ops -} - -fn all_shout_opcodes() -> HashSet { - use RiscvOpcode::*; - // RV32 B1 uses implicit Shout tables only for opcodes with a closed-form MLE implementation. - // RV32M ops are proven in-circuit (MUL/DIVU/REMU) or not yet supported (MULH/DIV/REM, etc.). - HashSet::from([And, Xor, Or, Sub, Add, Sltu, Slt, Eq, Neq, Sll, Srl, Sra]) -} - -/// High-level “few lines” builder for proving/verifying an RV32 program using the B1 shared-bus step circuit. -/// -/// This: -/// - chooses parameters + Ajtai committer automatically, -/// - infers the minimal Shout table set from the program (unless overridden), -/// - enforces RV32 B1 step linking, and -/// - (optionally) proves output binding against RAM. -#[derive(Clone, Debug)] -pub struct Rv32B1 { - program_base: u64, - program_bytes: Vec, - xlen: usize, - ram_bytes: usize, - chunk_size: usize, - max_steps: Option, - mode: FoldingMode, - shout_auto_minimal: bool, - shout_ops: Option>, - output_claims: ProgramIO, - ram_init: HashMap, -} - -/// Default instruction cap for RV32B1 runs when `max_steps` is not specified. -/// -/// The runner stops early if the guest halts (e.g. via `ecall`), so this is only a safety bound -/// against non-halting guests. -const DEFAULT_RV32B1_MAX_STEPS: usize = 1 << 20; - -impl Rv32B1 { - /// Create a runner from ROM bytes (must be a valid RV32 program encoding). 
- pub fn from_rom(program_base: u64, program_bytes: &[u8]) -> Self { - Self { - program_base, - program_bytes: program_bytes.to_vec(), - xlen: 32, - ram_bytes: 0x200, - chunk_size: 1, - max_steps: None, - mode: FoldingMode::Optimized, - shout_auto_minimal: true, - shout_ops: None, - output_claims: ProgramIO::new(), - ram_init: HashMap::new(), - } - } - - pub fn xlen(mut self, xlen: usize) -> Self { - self.xlen = xlen; - self - } - - pub fn ram_bytes(mut self, ram_bytes: usize) -> Self { - self.ram_bytes = ram_bytes; - self - } - - pub fn chunk_size(mut self, chunk_size: usize) -> Self { - self.chunk_size = chunk_size; - self - } - - /// Limit the number of instructions executed from the decoded program. - /// - /// This is primarily for tests/benchmarks that want a tiny trace, or for non-halting guests - /// where you want to prove only a prefix of execution. - pub fn max_steps(mut self, max_steps: usize) -> Self { - self.max_steps = Some(max_steps); - self - } - - pub fn mode(mut self, mode: FoldingMode) -> Self { - self.mode = mode; - self - } - - pub fn shout_auto_minimal(mut self) -> Self { - self.shout_ops = None; - self.shout_auto_minimal = true; - self - } - - pub fn shout_all(mut self) -> Self { - self.shout_ops = None; - self.shout_auto_minimal = false; - self - } - - pub fn shout_ops(mut self, ops: impl IntoIterator) -> Self { - self.shout_ops = Some(ops.into_iter().collect()); - self.shout_auto_minimal = false; - self - } - - pub fn output(mut self, output_addr: u64, expected_output: F) -> Self { - self.output_claims = ProgramIO::new().with_output(output_addr, expected_output); - self - } - - pub fn output_claim(mut self, addr: u64, value: F) -> Self { - self.output_claims = self.output_claims.with_output(addr, value); - self - } - - pub fn ram_init_u32(mut self, addr: u64, value: u32) -> Self { - self.ram_init.insert(addr, value as u64); - self - } - - pub fn prove(self) -> Result { - if self.xlen != 32 { - return Err(PiCcsError::InvalidInput(format!( - 
"RV32 B1 MVP requires xlen == 32 (got {})", - self.xlen - ))); - } - if self.program_bytes.is_empty() { - return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); - } - if self.chunk_size == 0 { - return Err(PiCcsError::InvalidInput("chunk_size must be non-zero".into())); - } - if self.ram_bytes == 0 { - return Err(PiCcsError::InvalidInput("ram_bytes must be non-zero".into())); - } - if self.program_base != 0 { - return Err(PiCcsError::InvalidInput( - "RV32 B1 MVP requires program_base == 0 (addresses are indices into PROG/RAM layouts)".into(), - )); - } - if self.program_bytes.len() % 4 != 0 { - return Err(PiCcsError::InvalidInput( - "program_bytes must be 4-byte aligned (RV32 B1 runner does not support RVC)".into(), - )); - } - for (i, chunk) in self.program_bytes.chunks_exact(4).enumerate() { - let first_half = u16::from_le_bytes([chunk[0], chunk[1]]); - if (first_half & 0b11) != 0b11 { - return Err(PiCcsError::InvalidInput(format!( - "RV32 B1 runner does not support compressed instructions (RVC): found compressed encoding at word index {i}" - ))); - } - } - - let program = decode_program(&self.program_bytes) - .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; - let using_default_max_steps = self.max_steps.is_none(); - let max_steps = match self.max_steps { - Some(n) => { - if n == 0 { - return Err(PiCcsError::InvalidInput("max_steps must be non-zero".into())); - } - n - } - None => DEFAULT_RV32B1_MAX_STEPS.max(program.len()), - }; - let mut twist = neo_memory::riscv::lookups::RiscvMemory::with_program_in_twist( - self.xlen, - PROG_ID, - /*base_addr=*/ 0, - &self.program_bytes, - ); - let shout = RiscvShoutTables::new(self.xlen); - - let (prog_layout, initial_mem) = neo_memory::riscv::rom_init::prog_rom_layout_and_init_words( - PROG_ID, - /*base_addr=*/ 0, - &self.program_bytes, - ) - .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; - let mut initial_mem = 
initial_mem; - for (addr, value) in self.ram_init { - let value = value as u32 as u64; - initial_mem.insert((neo_memory::riscv::lookups::RAM_ID.0, addr), F::from_u64(value)); - twist.store(neo_memory::riscv::lookups::RAM_ID, addr, value); - } - - let (k_ram, d_ram) = pow2_ceil_k(self.ram_bytes.max(4)); - let mem_layouts = HashMap::from([ - ( - neo_memory::riscv::lookups::RAM_ID.0, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - (PROG_ID.0, prog_layout), - ]); - - // Shout tables (either inferred, all, or explicitly provided). - let mut shout_ops = match &self.shout_ops { - Some(ops) => ops.clone(), - None if self.shout_auto_minimal => infer_required_shout_opcodes(&program), - None => all_shout_opcodes(), - }; - // The ADD table is required even for programs without explicit ADD/ADDI due to address/PC wiring. - shout_ops.insert(RiscvOpcode::Add); - - let mut table_specs: HashMap = HashMap::new(); - for op in shout_ops { - let table_id = shout.opcode_to_id(op).0; - table_specs.insert( - table_id, - LutTableSpec::RiscvOpcode { - opcode: op, - xlen: self.xlen, - }, - ); - } - let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); - shout_table_ids.sort_unstable(); - - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) - .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; - - // Session + Ajtai committer + params (auto-picked for this CCS). - let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; - let params = session.params().clone(); - let committer = session.committer().clone(); - - let mut vm = RiscvCpu::new(self.xlen); - vm.load_program(/*base=*/ 0, program); - - let empty_tables: HashMap> = HashMap::new(); - let lut_lanes: HashMap = HashMap::new(); - - // CPU arithmetization (builds chunk witnesses and commits them). 
- let mut cpu = R1csCpu::new( - ccs_base, - params, - committer, - layout.m_in, - &empty_tables, - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ); - cpu = cpu - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, - self.chunk_size, - ) - .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; - - // Always enforce step-to-step chunk chaining for RV32 B1. - session.set_step_linking(rv32_b1_step_linking_config(&layout)); - - // Execute + collect step bundles (and aux for output binding). - session.execute_shard_shared_cpu_bus( - vm, - twist, - shout, - /*max_steps=*/ max_steps, - self.chunk_size, - &mem_layouts, - &empty_tables, - &table_specs, - &lut_lanes, - &initial_mem, - &cpu, - )?; - if using_default_max_steps { - let aux = session - .shared_bus_aux() - .ok_or_else(|| PiCcsError::InvalidInput("missing shared-bus aux (halt status unavailable)".into()))?; - if !aux.did_halt { - return Err(PiCcsError::InvalidInput(format!( - "RV32 execution did not halt within max_steps={max_steps}; call .max_steps(...) to raise the limit or ensure the guest halts (e.g. via ecall)" - ))); - } - } - - // Enforce that the *statement* initial memory matches chunk 0's public MemInit. - let steps_public = session.steps_public(); - rv32_b1_enforce_chunk0_mem_init_matches_statement(&mem_layouts, &initial_mem, &steps_public)?; - - let ccs = cpu.ccs.clone(); - - // Prove phase (timed) - let prove_start = time_now(); - let proof = if self.output_claims.is_empty() { - session.fold_and_prove(&ccs)? - } else { - let ob_cfg = OutputBindingConfig::new(d_ram, self.output_claims.clone()); - session.fold_and_prove_with_output_binding_auto_simple(&ccs, &ob_cfg)? 
- }; - let prove_duration = elapsed_duration(prove_start); - - Ok(Rv32B1Run { - session, - proof, - ccs, - layout, - mem_layouts, - initial_mem, - ram_num_bits: d_ram, - output_claims: self.output_claims, - prove_duration, - verify_duration: None, - }) - } -} - -pub struct Rv32B1Run { - session: FoldingSession, - proof: ShardProof, - ccs: CcsStructure, - layout: Rv32B1Layout, - mem_layouts: HashMap, - initial_mem: HashMap<(u32, u64), F>, - ram_num_bits: usize, - output_claims: ProgramIO, - prove_duration: Duration, - verify_duration: Option, -} - -impl Rv32B1Run { - pub fn params(&self) -> &NeoParams { - self.session.params() - } - - pub fn ccs(&self) -> &CcsStructure { - &self.ccs - } - - pub fn verify(&mut self) -> Result<(), PiCcsError> { - let verify_start = time_now(); - let ok = if self.output_claims.is_empty() { - self.session.verify_collected(&self.ccs, &self.proof)? - } else { - let ob_cfg = OutputBindingConfig::new(self.ram_num_bits, self.output_claims.clone()); - self.session - .verify_with_output_binding_collected_simple(&self.ccs, &self.proof, &ob_cfg)? - }; - self.verify_duration = Some(elapsed_duration(verify_start)); - - if !ok { - return Err(PiCcsError::ProtocolError("verification failed".into())); - } - Ok(()) - } - - pub fn proof(&self) -> &ShardProof { - &self.proof - } - - /// Access the collected per-step witness bundles (includes private witness). - /// - /// This is intended for debugging/profiling and for tests that want to inspect witness shapes. 
- pub fn steps_witness(&self) -> &[StepWitnessBundle] { - self.session.steps_witness() - } - - pub fn steps_public(&self) -> Vec> { - self.session.steps_public() - } - - pub fn final_boundary_state(&self) -> Result { - let steps_public = self.steps_public(); - let last = steps_public - .last() - .ok_or_else(|| PiCcsError::InvalidInput("no steps collected".into()))?; - extract_boundary_state(&self.layout, &last.mcs_inst.x) - .map_err(|e| PiCcsError::InvalidInput(format!("extract_boundary_state failed: {e}"))) - } - - pub fn verify_output_claim(&self, output_addr: u64, expected_output: F) -> Result { - let ob_cfg = simple_output_config(self.ram_num_bits, output_addr, expected_output); - self.session - .verify_with_output_binding_collected_simple(&self.ccs, &self.proof, &ob_cfg) - } - - pub fn verify_default_output_claim(&self) -> Result { - if self.output_claims.is_empty() { - return Err(PiCcsError::InvalidInput("no output claim configured".into())); - }; - let ob_cfg = OutputBindingConfig::new(self.ram_num_bits, self.output_claims.clone()); - self.session - .verify_with_output_binding_collected_simple(&self.ccs, &self.proof, &ob_cfg) - } - - pub fn verify_output_claims(&self, output_claims: ProgramIO) -> Result { - if output_claims.is_empty() { - return Err(PiCcsError::InvalidInput("output_claims must be non-empty".into())); - } - let ob_cfg = OutputBindingConfig::new(self.ram_num_bits, output_claims); - self.session - .verify_with_output_binding_collected_simple(&self.ccs, &self.proof, &ob_cfg) - } - - /// Original unpadded RV32 trace length (instruction count), if this run was built via shared-bus execution. - pub fn riscv_trace_len(&self) -> Result { - let aux = self - .session - .shared_bus_aux() - .ok_or_else(|| PiCcsError::InvalidInput("missing shared-bus aux (trace length unavailable)".into()))?; - Ok(aux.original_len) - } - - /// CCS constraint count (rows). For RV32 B1 this is the size of the per-chunk step circuit. 
- pub fn ccs_num_constraints(&self) -> usize { - self.ccs.n - } - - /// CCS variable count (cols). For RV32 B1 this is the number of witness variables per chunk. - pub fn ccs_num_variables(&self) -> usize { - self.ccs.m - } - - /// Number of folding steps proven (one per collected chunk). - pub fn fold_count(&self) -> usize { - self.proof.steps.len() - } - - /// Count the number of Shout lookups actually used across the executed trace (active rows only). - pub fn shout_lookup_count(&self) -> Result { - let mut count = 0usize; - for step in self.session.steps_witness() { - let x = &step.mcs.0.x; - let w = &step.mcs.1.w; - let m_in = step.mcs.0.m_in; - - let z_at = |idx: usize| -> Result { - if idx < m_in { - x.get(idx).copied().ok_or_else(|| { - PiCcsError::InvalidInput(format!( - "witness index {idx} out of bounds for public input len={m_in}" - )) - }) - } else { - let w_idx = idx - m_in; - w.get(w_idx).copied().ok_or_else(|| { - PiCcsError::InvalidInput(format!( - "witness index {idx} (w[{w_idx}]) out of bounds for witness len={}", - w.len() - )) - }) - } - }; - - for j in 0..self.layout.chunk_size { - if z_at(self.layout.is_active(j))? == F::ZERO { - continue; - } - for inst in &self.layout.bus.shout_cols { - for lane in &inst.lanes { - let col = self.layout.bus.bus_cell(lane.has_lookup, j); - if z_at(col)? != F::ZERO { - count += 1; - } - } - } - } - } - Ok(count) - } - - pub fn mem_layouts(&self) -> &HashMap { - &self.mem_layouts - } - - pub fn initial_mem(&self) -> &HashMap<(u32, u64), F> { - &self.initial_mem - } - - pub fn prove_duration(&self) -> Duration { - self.prove_duration - } - - pub fn verify_duration(&self) -> Option { - self.verify_duration - } -} diff --git a/crates/neo-fold/src/riscv_trace_shard.rs b/crates/neo-fold/src/riscv_trace_shard.rs new file mode 100644 index 00000000..75ec868f --- /dev/null +++ b/crates/neo-fold/src/riscv_trace_shard.rs @@ -0,0 +1,1542 @@ +//! Convenience runner for RV32 trace-wiring CCS (time-in-rows). +//! +//! 
This is an ergonomic wrapper around the existing trace wiring artifacts: +//! - `neo_memory::riscv::trace` for execution-table extraction, and +//! - `neo_memory::riscv::ccs::trace` for fixed-width trace wiring CCS. +//! +//! The runner intentionally targets the current Tier 2.1 scope: +//! - fixed-width trace-wiring CCS steps with PROG/REG/RAM sidecar instances, +//! - no decode/semantics sidecar proofs in this wrapper yet. + +#![allow(non_snake_case)] + +use std::collections::{HashMap, HashSet}; +use std::marker::PhantomData; +use std::time::Duration; + +use crate::output_binding::OutputBindingConfig; +use crate::pi_ccs::FoldingMode; +use crate::session::FoldingSession; +use crate::shard::{ShardProof, StepLinkingConfig}; +use crate::PiCcsError; +use neo_ajtai::AjtaiSModule; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_ccs::CcsStructure; +use neo_math::{F, K}; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::output_check::ProgramIO; +use neo_memory::plain::{LutTable, PlainMemLayout}; +use neo_memory::riscv::ccs::{ + build_rv32_trace_wiring_ccs, build_rv32_trace_wiring_ccs_with_reserved_rows, + rv32_trace_ccs_witness_from_exec_table, rv32_trace_shared_bus_requirements_with_specs, + rv32_trace_shared_cpu_bus_config_with_specs, Rv32TraceCcsLayout, TraceShoutBusSpec, +}; +use neo_memory::riscv::exec_table::{Rv32ExecRow, Rv32ExecTable}; +use neo_memory::riscv::lookups::{ + decode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, RAM_ID, REG_ID, +}; +use neo_memory::riscv::rom_init::prog_rom_layout_and_init_words; +use neo_memory::riscv::trace::{ + extract_twist_lanes_over_time, rv32_decode_lookup_backed_cols, rv32_decode_lookup_backed_row_from_instr_word, + rv32_decode_lookup_table_id_for_col, rv32_width_lookup_backed_cols, rv32_width_lookup_table_id_for_col, + rv32_width_sidecar_witness_from_exec_table, 
Rv32DecodeSidecarLayout, Rv32WidthSidecarLayout, TwistLaneOverTime, +}; +use neo_memory::witness::{LutInstance, LutWitness, MemInstance, MemWitness, StepWitnessBundle}; +use neo_memory::{LutTableSpec, MemInit, R1csCpu}; +use neo_params::NeoParams; +use neo_vm_trace::{ShoutEvent, ShoutId, StepTrace, Twist as _, TwistOpKind, VmTrace}; +use p3_field::PrimeCharacteristicRing; +use p3_field::PrimeField64; + +#[cfg(target_arch = "wasm32")] +use js_sys::Date; +#[cfg(not(target_arch = "wasm32"))] +use std::time::Instant; + +#[cfg(target_arch = "wasm32")] +type TimePoint = f64; +#[cfg(not(target_arch = "wasm32"))] +type TimePoint = Instant; + +#[inline] +fn time_now() -> TimePoint { + #[cfg(target_arch = "wasm32")] + { + Date::now() + } + #[cfg(not(target_arch = "wasm32"))] + { + Instant::now() + } +} + +#[inline] +fn elapsed_duration(start: TimePoint) -> Duration { + #[cfg(target_arch = "wasm32")] + { + let elapsed_ms = Date::now() - start; + Duration::from_secs_f64(elapsed_ms / 1_000.0) + } + #[cfg(not(target_arch = "wasm32"))] + { + start.elapsed() + } +} + +/// Hard instruction cap for trace-wiring mode (Option C). +const DEFAULT_RV32_TRACE_MAX_STEPS: usize = 1 << 20; + +/// Default per-step trace rows for trace-mode IVC. +/// +/// The full trace is split into fixed-size chunks of this row count (except when the whole +/// trace is smaller), and those chunks are folded with step-linking. 
+const DEFAULT_RV32_TRACE_CHUNK_ROWS: usize = 1 << 16; + +fn max_ram_addr_from_exec(exec: &Rv32ExecTable) -> Option { + exec.rows + .iter() + .filter(|r| r.active) + .flat_map(|r| r.ram_events.iter().map(|e| e.addr)) + .max() +} + +fn required_bits_for_max_addr(max_addr: u64) -> usize { + if max_addr == 0 { + 1 + } else { + (u64::BITS - max_addr.leading_zeros()) as usize + } +} + +fn build_twist_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[TwistLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if x_prefix.len() != m_in { + return Err(format!( + "build_twist_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_twist_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::empty::<(usize, usize)>(), + core::iter::once((ell_addr, lanes)), + )?; + if bus.twist_cols.len() != 1 || !bus.shout_cols.is_empty() { + return Err("build_twist_only_bus_z: expected 1 twist instance and 0 shout instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let twist = &bus.twist_cols[0]; + for (lane_idx, cols) in twist.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_read.len() != t || lane.has_write.len() != t { + return Err("build_twist_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has_r = lane.has_read[j]; + let has_w = lane.has_write[j]; + + z[bus.bus_cell(cols.has_read, j)] = if has_r { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.has_write, j)] = if has_w { F::ONE } else { F::ZERO }; + + z[bus.bus_cell(cols.rv, j)] = if has_r { F::from_u64(lane.rv[j]) } else { F::ZERO }; + z[bus.bus_cell(cols.wv, j)] = if has_w { F::from_u64(lane.wv[j]) } else { F::ZERO }; + z[bus.bus_cell(cols.inc, j)] = if has_w { 
lane.inc_at_write_addr[j] } else { F::ZERO }; + + for (bit_idx, col_id) in cols.ra_bits.clone().enumerate() { + let bit_is_set = bit_idx < (u64::BITS as usize) && ((lane.ra[j] >> bit_idx) & 1) == 1; + z[bus.bus_cell(col_id, j)] = if bit_is_set { F::ONE } else { F::ZERO }; + } + for (bit_idx, col_id) in cols.wa_bits.clone().enumerate() { + let bit_is_set = bit_idx < (u64::BITS as usize) && ((lane.wa[j] >> bit_idx) & 1) == 1; + z[bus.bus_cell(col_id, j)] = if bit_is_set { F::ONE } else { F::ZERO }; + } + } + } + + Ok(z) +} + +fn mem_init_from_u64_sparse(sparse: &HashMap, k: usize, label: &str) -> Result, PiCcsError> { + let mut pairs = Vec::<(u64, F)>::new(); + for (&addr, &value) in sparse { + let addr_usize = usize::try_from(addr) + .map_err(|_| PiCcsError::InvalidInput(format!("{label} init addr does not fit usize: addr={addr}")))?; + if addr_usize >= k { + return Err(PiCcsError::InvalidInput(format!( + "{label} init addr out of range: addr={addr} >= k={k}" + ))); + } + if value != 0 { + pairs.push((addr, F::from_u64(value))); + } + } + pairs.sort_by_key(|(addr, _)| *addr); + Ok(if pairs.is_empty() { + MemInit::Zero + } else { + MemInit::Sparse(pairs) + }) +} + +fn final_reg_state_dense(exec: &Rv32ExecTable, reg_init: &HashMap) -> Result, PiCcsError> { + let mut regs = [0u64; 32]; + for (®, &value) in reg_init { + if reg >= 32 { + return Err(PiCcsError::InvalidInput(format!( + "reg_init_u32: register index out of range: reg={reg} (expected 0..32)" + ))); + } + if reg == 0 && value != 0 { + return Err(PiCcsError::InvalidInput( + "reg_init_u32: x0 must be 0 (non-zero init is forbidden)".into(), + )); + } + regs[reg as usize] = value as u32 as u64; + } + regs[0] = 0; + + for r in exec.rows.iter().filter(|r| r.active) { + if let Some(w) = &r.reg_write_lane0 { + if w.addr >= 32 { + return Err(PiCcsError::InvalidInput(format!( + "trace register write addr out of range at cycle {}: addr={}", + r.cycle, w.addr + ))); + } + if w.addr == 0 { + return 
Err(PiCcsError::InvalidInput(format!( + "trace writes x0 at cycle {} which is invalid", + r.cycle + ))); + } + regs[w.addr as usize] = w.value as u32 as u64; + regs[0] = 0; + } + } + + Ok(regs.iter().map(|&v| F::from_u64(v)).collect()) +} + +fn final_ram_state_dense(exec: &Rv32ExecTable, ram_init: &HashMap, k: usize) -> Result, PiCcsError> { + let mut out = vec![F::ZERO; k]; + for (&addr, &value) in ram_init { + let addr_usize = usize::try_from(addr) + .map_err(|_| PiCcsError::InvalidInput(format!("ram_init_u32: addr does not fit usize: addr={addr}")))?; + if addr_usize >= k { + return Err(PiCcsError::InvalidInput(format!( + "ram_init_u32: addr out of range for output binding domain: addr={addr} >= k={k}" + ))); + } + out[addr_usize] = F::from_u64(value as u32 as u64); + } + + for r in exec.rows.iter().filter(|r| r.active) { + for e in &r.ram_events { + if e.kind != TwistOpKind::Write { + continue; + } + let addr_usize = usize::try_from(e.addr).map_err(|_| { + PiCcsError::InvalidInput(format!( + "trace RAM write addr does not fit usize at cycle {}: addr={}", + r.cycle, e.addr + )) + })?; + if addr_usize >= k { + return Err(PiCcsError::InvalidInput(format!( + "trace RAM write addr out of range for output binding domain at cycle {}: addr={} >= k={k}", + r.cycle, e.addr + ))); + } + out[addr_usize] = F::from_u64(e.value as u32 as u64); + } + } + + Ok(out) +} + +fn init_reg_state(reg_init: &HashMap) -> Result<[u64; 32], PiCcsError> { + let mut regs = [0u64; 32]; + for (®, &value) in reg_init { + if reg >= 32 { + return Err(PiCcsError::InvalidInput(format!( + "reg_init_u32: register index out of range: reg={reg} (expected 0..32)" + ))); + } + if reg == 0 && value != 0 { + return Err(PiCcsError::InvalidInput( + "reg_init_u32: x0 must be 0 (non-zero init is forbidden)".into(), + )); + } + regs[reg as usize] = value as u32 as u64; + } + regs[0] = 0; + Ok(regs) +} + +fn init_ram_state(ram_init: &HashMap, ram_ell_addr: usize) -> Result, PiCcsError> { + if ram_ell_addr > 64 { 
+ return Err(PiCcsError::InvalidInput(format!( + "RAM ell_addr too large for u64 addressing: ell_addr={ram_ell_addr}" + ))); + } + + let mut ram = HashMap::::new(); + for (&addr, &value) in ram_init { + if ram_ell_addr < 64 && (addr >> ram_ell_addr) != 0 { + return Err(PiCcsError::InvalidInput(format!( + "RAM init addr out of range for ell_addr={ram_ell_addr}: addr={addr}" + ))); + } + let v = value as u32 as u64; + if v != 0 { + ram.insert(addr, v); + } + } + Ok(ram) +} + +fn reg_state_to_sparse_map(regs: &[u64; 32]) -> HashMap { + let mut out = HashMap::::new(); + for (idx, &value) in regs.iter().enumerate().skip(1) { + if value != 0 { + out.insert(idx as u64, value); + } + } + out +} + +fn apply_exec_chunk_writes_to_state( + chunk: &Rv32ExecTable, + regs: &mut [u64; 32], + ram: &mut HashMap, +) -> Result<(), PiCcsError> { + for r in chunk.rows.iter().filter(|r| r.active) { + if let Some(w) = &r.reg_write_lane0 { + if w.addr == 0 { + return Err(PiCcsError::InvalidInput(format!( + "trace writes x0 at cycle {} which is invalid", + r.cycle + ))); + } + if w.addr >= 32 { + return Err(PiCcsError::InvalidInput(format!( + "trace register write addr out of range at cycle {}: addr={}", + r.cycle, w.addr + ))); + } + regs[w.addr as usize] = w.value as u32 as u64; + regs[0] = 0; + } + + for e in &r.ram_events { + if e.kind != TwistOpKind::Write { + continue; + } + let value = e.value as u32 as u64; + if value == 0 { + ram.remove(&e.addr); + } else { + ram.insert(e.addr, value); + } + } + } + Ok(()) +} + +fn split_exec_into_fixed_chunks(exec: &Rv32ExecTable, chunk_rows: usize) -> Result, PiCcsError> { + if chunk_rows == 0 { + return Err(PiCcsError::InvalidInput("trace chunk_rows must be non-zero".into())); + } + if exec.rows.is_empty() { + return Err(PiCcsError::InvalidInput("trace execution table is empty".into())); + } + + if exec.rows.len() <= chunk_rows { + return Ok(vec![exec.clone()]); + } + + let mut out = Vec::::new(); + let total = exec.rows.len(); + let mut start = 
0usize; + while start < total { + let end = (start + chunk_rows).min(total); + let mut rows = exec.rows[start..end].to_vec(); + if rows.len() < chunk_rows { + let last = rows + .last() + .ok_or_else(|| PiCcsError::InvalidInput("trace chunk unexpectedly empty".into()))? + .clone(); + let mut cycle = last.cycle; + let pad_pc = last.pc_after; + let pad_halted = last.halted; + while rows.len() < chunk_rows { + cycle = cycle + .checked_add(1) + .ok_or_else(|| PiCcsError::InvalidInput("cycle overflow while chunk-padding trace".into()))?; + rows.push(neo_memory::riscv::exec_table::Rv32ExecRow::inactive( + cycle, pad_pc, pad_halted, + )); + } + } + out.push(Rv32ExecTable { rows }); + start = end; + } + + Ok(out) +} + +fn rv32_trace_chunk_to_witness( + layout: Rv32TraceCcsLayout, +) -> Box]) -> Vec + Send + Sync> { + Box::new(move |chunk: &[StepTrace]| { + rv32_trace_chunk_to_witness_checked(&layout, chunk) + .unwrap_or_else(|e| panic!("rv32_trace_chunk_to_witness failed for chunk_len={}: {e}", chunk.len())) + }) +} + +fn rv32_trace_chunk_to_witness_checked( + layout: &Rv32TraceCcsLayout, + chunk: &[StepTrace], +) -> Result, String> { + if chunk.is_empty() { + return Err("trace chunk witness: chunk must contain at least one step".into()); + } + if chunk.len() > layout.t { + return Err(format!( + "trace chunk witness: chunk.len()={} exceeds layout.t={}", + chunk.len(), + layout.t + )); + } + + let mut rows = Vec::with_capacity(layout.t); + for step in chunk { + rows.push(Rv32ExecRow::from_step(step)?); + } + + let mut cycle = rows + .last() + .ok_or_else(|| "trace chunk witness: empty rows after conversion".to_string())? 
+ .cycle; + let pad_pc = rows.last().expect("rows non-empty").pc_after; + let pad_halted = rows.last().expect("rows non-empty").halted; + while rows.len() < layout.t { + cycle = cycle + .checked_add(1) + .ok_or_else(|| "trace chunk witness: cycle overflow while padding".to_string())?; + rows.push(Rv32ExecRow::inactive(cycle, pad_pc, pad_halted)); + } + + let exec = Rv32ExecTable { rows }; + let (x, w) = rv32_trace_ccs_witness_from_exec_table(layout, &exec)?; + Ok(x.into_iter().chain(w).collect()) +} + +fn infer_required_trace_shout_opcodes(program: &[RiscvInstruction]) -> HashSet { + let mut ops = HashSet::new(); + // Required for shared wiring (address/PC arithmetic). + ops.insert(RiscvOpcode::Add); + for instr in program { + match instr { + RiscvInstruction::RAlu { op, .. } | RiscvInstruction::IAlu { op, .. } => { + ops.insert(*op); + } + RiscvInstruction::Branch { cond, .. } => { + ops.insert(cond.to_shout_opcode()); + } + // Address arithmetic in these classes uses ADD shout semantics. + RiscvInstruction::Load { .. } + | RiscvInstruction::Store { .. } + | RiscvInstruction::Jalr { .. } + | RiscvInstruction::Auipc { .. } => { + ops.insert(RiscvOpcode::Add); + } + _ => {} + } + } + ops +} + +fn program_requires_ram_sidecar(program: &[RiscvInstruction]) -> bool { + program.iter().any(|instr| { + matches!( + instr, + RiscvInstruction::Load { .. } + | RiscvInstruction::Store { .. } + | RiscvInstruction::LoadReserved { .. } + | RiscvInstruction::StoreConditional { .. } + | RiscvInstruction::Amo { .. } + ) + }) +} + +fn program_requires_width_lookup(program: &[RiscvInstruction]) -> bool { + program + .iter() + .any(|instr| matches!(instr, RiscvInstruction::Load { .. } | RiscvInstruction::Store { .. 
})) +} + +fn rv32_trace_table_specs(shout_ops: &HashSet) -> HashMap { + let shout = RiscvShoutTables::new(32); + let mut table_specs = HashMap::new(); + for &op in shout_ops { + let table_id = shout.opcode_to_id(op).0; + table_specs.insert(table_id, LutTableSpec::RiscvOpcode { opcode: op, xlen: 32 }); + } + table_specs +} + +fn build_rv32_decode_lookup_tables( + prog_layout: &PlainMemLayout, + prog_init_words: &HashMap<(u32, u64), F>, +) -> HashMap> { + let decode_layout = Rv32DecodeSidecarLayout::new(); + let decode_cols = rv32_decode_lookup_backed_cols(&decode_layout); + let mut out = HashMap::new(); + for &col_id in decode_cols.iter() { + let table_id = rv32_decode_lookup_table_id_for_col(col_id); + let mut content = vec![F::ZERO; prog_layout.k]; + for addr in 0..prog_layout.k { + let instr_word = prog_init_words + .get(&(PROG_ID.0, addr as u64)) + .copied() + .unwrap_or(F::ZERO) + .as_canonical_u64() as u32; + let row = rv32_decode_lookup_backed_row_from_instr_word(&decode_layout, instr_word, /*active=*/ true); + content[addr] = row[col_id]; + } + out.insert( + table_id, + LutTable { + table_id, + k: prog_layout.k, + d: prog_layout.d, + n_side: prog_layout.n_side, + content, + }, + ); + } + out +} + +fn inject_rv32_decode_lookup_events_into_trace( + trace: &mut VmTrace, + prog_layout: &PlainMemLayout, + prog_init_words: &HashMap<(u32, u64), F>, +) -> Result<(), PiCcsError> { + let decode_layout = Rv32DecodeSidecarLayout::new(); + let decode_cols = rv32_decode_lookup_backed_cols(&decode_layout); + for (step_idx, step) in trace.steps.iter_mut().enumerate() { + let prog_read = step + .twist_events + .iter() + .find(|e| e.twist_id == PROG_ID && e.kind == TwistOpKind::Read) + .ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "missing PROG read event while injecting decode lookup events at step {step_idx}" + )) + })?; + let addr = prog_read.addr; + if (addr as usize) >= prog_layout.k { + return Err(PiCcsError::ProtocolError(format!( + "decode lookup event addr 
out of range at step {step_idx}: addr={addr}, k={}", + prog_layout.k + ))); + } + let instr_word = prog_init_words + .get(&(PROG_ID.0, addr)) + .copied() + .unwrap_or_else(|| F::from_u64(prog_read.value)) + .as_canonical_u64() as u32; + let row = rv32_decode_lookup_backed_row_from_instr_word(&decode_layout, instr_word, /*active=*/ true); + for &col_id in decode_cols.iter() { + step.shout_events.push(ShoutEvent { + shout_id: ShoutId(rv32_decode_lookup_table_id_for_col(col_id)), + key: addr, + value: row[col_id].as_canonical_u64(), + }); + } + } + Ok(()) +} + +fn build_rv32_width_lookup_tables( + width_layout: &Rv32WidthSidecarLayout, + exec: &Rv32ExecTable, + trace_steps: usize, +) -> Result<(HashMap>, usize), PiCcsError> { + // Width lookup tables here are execution-indexed helper transport tables. + // They are not a standalone trust root: Route-A width residual claims bind + // every opened helper value back to committed trace columns (`ram_rv`, + // `rs2_val`), and WB/WP enforce the associated bitness/quiescence properties. 
+ let max_cycle = exec + .rows + .iter() + .take(trace_steps) + .map(|r| r.cycle) + .max() + .unwrap_or(0); + let cycle_d = required_bits_for_max_addr(max_cycle).max(2); + let cycle_k = 1usize + .checked_shl(cycle_d as u32) + .ok_or_else(|| PiCcsError::InvalidInput(format!("width lookup cycle width too large: d={cycle_d}")))?; + + let wit = rv32_width_sidecar_witness_from_exec_table(width_layout, exec); + let width_cols = rv32_width_lookup_backed_cols(width_layout); + let mut out = HashMap::new(); + for &col_id in width_cols.iter() { + let table_id = rv32_width_lookup_table_id_for_col(col_id); + let mut content = vec![F::ZERO; cycle_k]; + for (i, row) in exec.rows.iter().enumerate().take(trace_steps) { + let cycle = row.cycle as usize; + if cycle >= cycle_k { + return Err(PiCcsError::ProtocolError(format!( + "width lookup cycle out of range at row {i}: cycle={}, k={cycle_k}", + row.cycle + ))); + } + content[cycle] = wit.cols[col_id][i]; + } + out.insert( + table_id, + LutTable { + table_id, + k: cycle_k, + d: cycle_d, + n_side: 2, + content, + }, + ); + } + Ok((out, cycle_d)) +} + +fn inject_rv32_width_lookup_events_into_trace( + trace: &mut VmTrace, + exec: &Rv32ExecTable, + width_layout: &Rv32WidthSidecarLayout, +) -> Result<(), PiCcsError> { + if trace.steps.len() > exec.rows.len() { + return Err(PiCcsError::ProtocolError(format!( + "width lookup injection drift: trace steps {} > exec rows {}", + trace.steps.len(), + exec.rows.len() + ))); + } + let wit = rv32_width_sidecar_witness_from_exec_table(width_layout, exec); + let width_cols = rv32_width_lookup_backed_cols(width_layout); + for (i, step) in trace.steps.iter_mut().enumerate() { + let cycle = exec + .rows + .get(i) + .ok_or_else(|| PiCcsError::ProtocolError("missing exec row while injecting width lookups".into()))? 
+ .cycle; + for &col_id in width_cols.iter() { + step.shout_events.push(ShoutEvent { + shout_id: ShoutId(rv32_width_lookup_table_id_for_col(col_id)), + key: cycle, + value: wit.cols[col_id][i].as_canonical_u64(), + }); + } + } + Ok(()) +} + +/// High-level builder for proving/verifying the RV32 trace wiring CCS. +/// +/// This path is intentionally narrow: +/// - builds a padded execution table, +/// - proves one or more trace-wiring CCS steps (IVC), +/// - verifies the resulting shard proof. +#[derive(Clone, Copy, Debug, Default)] +enum OutputTarget { + #[default] + Ram, + Reg, +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct Rv32TraceProvePhaseDurations { + pub setup: Duration, + pub chunk_build_commit: Duration, + pub fold_and_prove: Duration, +} + +#[derive(Clone, Debug)] +pub struct Rv32TraceWiring { + program_base: u64, + program_bytes: Vec, + xlen: usize, + max_steps: Option, + min_trace_len: usize, + chunk_rows: Option, + shared_cpu_bus: bool, + mode: FoldingMode, + ram_init: HashMap, + reg_init: HashMap, + output_claims: ProgramIO, + output_target: OutputTarget, + shout_ops: Option>, + extra_lut_table_specs: HashMap, + extra_shout_bus_specs: Vec, +} + +impl Rv32TraceWiring { + /// Create a trace runner from ROM bytes. + pub fn from_rom(program_base: u64, program_bytes: &[u8]) -> Self { + Self { + program_base, + program_bytes: program_bytes.to_vec(), + xlen: 32, + max_steps: None, + min_trace_len: 4, + chunk_rows: None, + shared_cpu_bus: true, + mode: FoldingMode::Optimized, + ram_init: HashMap::new(), + reg_init: HashMap::new(), + output_claims: ProgramIO::new(), + output_target: OutputTarget::Ram, + shout_ops: None, + extra_lut_table_specs: HashMap::new(), + extra_shout_bus_specs: Vec::new(), + } + } + + pub fn xlen(mut self, xlen: usize) -> Self { + self.xlen = xlen; + self + } + + /// Lower-bound for execution-table length. + /// + /// Final `t` is `max(trace_len, min_trace_len)`. 
+ pub fn min_trace_len(mut self, min_trace_len: usize) -> Self { + self.min_trace_len = min_trace_len.max(1); + self + } + + /// Fixed rows per trace step for IVC folding. + /// + /// The trace is split into fixed-size chunks, each chunk is proven with the same step CCS, + /// and step-linking enforces `pc_final -> pc0`. + pub fn chunk_rows(mut self, chunk_rows: usize) -> Self { + self.chunk_rows = Some(chunk_rows); + self + } + + /// Toggle shared-CPU-bus trace proving mode. + /// + /// `true` is the intended production default; `false` keeps the legacy no-shared-bus path. + pub fn shared_cpu_bus(mut self, enabled: bool) -> Self { + self.shared_cpu_bus = enabled; + self + } + + /// Bound executed instruction count. + pub fn max_steps(mut self, max_steps: usize) -> Self { + self.max_steps = Some(max_steps); + self + } + + pub fn mode(mut self, mode: FoldingMode) -> Self { + self.mode = mode; + self + } + + /// Initialize RAM byte-addressed word cell to a u32 value. + pub fn ram_init_u32(mut self, addr: u64, value: u32) -> Self { + self.ram_init.insert(addr, value as u64); + self + } + + /// Initialize register `reg` (x0..x31) to a u32 value. 
+ pub fn reg_init_u32(mut self, reg: u64, value: u32) -> Self { + self.reg_init.insert(reg, value as u64); + self + } + + pub fn output(mut self, output_addr: u64, expected_output: F) -> Self { + self.output_claims = ProgramIO::new().with_output(output_addr, expected_output); + self.output_target = OutputTarget::Ram; + self + } + + pub fn output_claim(mut self, addr: u64, value: F) -> Self { + if !matches!(self.output_target, OutputTarget::Ram) { + self.output_target = OutputTarget::Ram; + self.output_claims = ProgramIO::new(); + } + self.output_claims = self.output_claims.with_output(addr, value); + self + } + + pub fn reg_output(mut self, reg: u64, expected: F) -> Self { + self.output_claims = ProgramIO::new().with_output(reg, expected); + self.output_target = OutputTarget::Reg; + self + } + + pub fn reg_output_claim(mut self, reg: u64, expected: F) -> Self { + if !matches!(self.output_target, OutputTarget::Reg) { + self.output_target = OutputTarget::Reg; + self.output_claims = ProgramIO::new(); + } + self.output_claims = self.output_claims.with_output(reg, expected); + self + } + + /// Use the default program-inferred minimal shout set. + pub fn shout_auto_minimal(mut self) -> Self { + self.shout_ops = None; + self + } + + /// Optional override for shout tables. + /// + /// The override must be a superset of the program-inferred required shout set. + pub fn shout_ops(mut self, ops: impl IntoIterator) -> Self { + self.shout_ops = Some(ops.into_iter().collect()); + self + } + + /// Add an extra implicit lookup-table spec by `table_id`. + /// + /// The id must not collide with inferred opcode-table ids. + pub fn extra_lut_table_spec(mut self, table_id: u32, spec: LutTableSpec) -> Self { + self.extra_lut_table_specs.insert(table_id, spec); + self + } + + /// Optional extra Shout family geometry for trace shared-bus mode. + /// + /// Each spec adds/overrides a `table_id -> ell_addr` mapping used to size shout lanes. 
+ pub fn extra_shout_bus_specs(mut self, specs: impl IntoIterator) -> Self { + self.extra_shout_bus_specs = specs.into_iter().collect(); + self + } + + pub fn prove(self) -> Result { + if self.xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 trace wiring runner requires xlen == 32 (got {})", + self.xlen + ))); + } + if self.program_base != 0 { + return Err(PiCcsError::InvalidInput( + "RV32 trace wiring runner requires program_base == 0".into(), + )); + } + if self.program_bytes.is_empty() { + return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); + } + if self.min_trace_len > DEFAULT_RV32_TRACE_MAX_STEPS { + return Err(PiCcsError::InvalidInput(format!( + "min_trace_len={} exceeds trace-mode hard cap {}. Increase chunk_rows and prove in chunks for longer executions.", + self.min_trace_len, DEFAULT_RV32_TRACE_MAX_STEPS + ))); + } + if self.program_bytes.len() % 4 != 0 { + return Err(PiCcsError::InvalidInput( + "program_bytes must be 4-byte aligned (RVC is not supported)".into(), + )); + } + for (i, chunk) in self.program_bytes.chunks_exact(4).enumerate() { + let first_half = u16::from_le_bytes([chunk[0], chunk[1]]); + if (first_half & 0b11) != 0b11 { + return Err(PiCcsError::InvalidInput(format!( + "compressed instruction encoding (RVC) is not supported at word index {i}" + ))); + } + } + + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + let using_default_max_steps = self.max_steps.is_none(); + let max_steps = match self.max_steps { + Some(n) => { + if n == 0 { + return Err(PiCcsError::InvalidInput("max_steps must be non-zero".into())); + } + if n > DEFAULT_RV32_TRACE_MAX_STEPS { + return Err(PiCcsError::InvalidInput(format!( + "max_steps={} exceeds trace-mode hard cap {}. 
Increase chunk_rows and prove in chunks for longer executions.", + n, DEFAULT_RV32_TRACE_MAX_STEPS + ))); + } + n + } + None => DEFAULT_RV32_TRACE_MAX_STEPS, + }; + if !self.shared_cpu_bus { + return Err(PiCcsError::InvalidInput( + "RV32 trace wiring no-shared fallback is removed; Phase 2 decode lookup requires shared_cpu_bus=true" + .into(), + )); + } + let ram_init_map = self.ram_init.clone(); + let reg_init_map = self.reg_init.clone(); + let output_claims = self.output_claims.clone(); + let output_target = self.output_target; + let (prog_layout, prog_init_words) = + prog_rom_layout_and_init_words::(PROG_ID, /*base_addr=*/ 0, &self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + + let mut vm = RiscvCpu::new(self.xlen); + vm.load_program(/*base=*/ 0, program.clone()); + + let mut twist = + RiscvMemory::with_program_in_twist(self.xlen, PROG_ID, /*base_addr=*/ 0, &self.program_bytes); + for (&addr, &value) in &ram_init_map { + twist.store(RAM_ID, addr, value as u32 as u64); + } + for (®, &value) in ®_init_map { + if reg >= 32 { + return Err(PiCcsError::InvalidInput(format!( + "reg_init_u32: register index out of range: reg={reg} (expected 0..32)" + ))); + } + if reg == 0 && value != 0 { + return Err(PiCcsError::InvalidInput( + "reg_init_u32: x0 must be 0 (non-zero init is forbidden)".into(), + )); + } + twist.store(REG_ID, reg, value as u32 as u64); + } + let shout = RiscvShoutTables::new(self.xlen); + + let mut trace = neo_vm_trace::trace_program(vm, twist, shout, max_steps) + .map_err(|e| PiCcsError::InvalidInput(format!("trace_program failed: {e}")))?; + + if using_default_max_steps && !trace.did_halt() { + return Err(PiCcsError::InvalidInput(format!( + "RV32 execution did not halt within max_steps={max_steps}; call .max_steps(...) 
to raise the limit or ensure the guest halts" + ))); + } + + let target_len = trace.steps.len().max(self.min_trace_len); + if target_len > DEFAULT_RV32_TRACE_MAX_STEPS { + return Err(PiCcsError::InvalidInput(format!( + "trace length {} exceeds trace-mode hard cap {}. Increase chunk_rows and prove in chunks for longer executions.", + target_len, DEFAULT_RV32_TRACE_MAX_STEPS + ))); + } + if self.shared_cpu_bus { + inject_rv32_decode_lookup_events_into_trace(&mut trace, &prog_layout, &prog_init_words)?; + } + let exec = Rv32ExecTable::from_trace_padded(&trace, target_len) + .map_err(|e| PiCcsError::InvalidInput(format!("Rv32ExecTable::from_trace_padded failed: {e}")))?; + exec.validate_cycle_chain() + .map_err(|e| PiCcsError::InvalidInput(format!("validate_cycle_chain failed: {e}")))?; + exec.validate_pc_chain() + .map_err(|e| PiCcsError::InvalidInput(format!("validate_pc_chain failed: {e}")))?; + exec.validate_halted_tail() + .map_err(|e| PiCcsError::InvalidInput(format!("validate_halted_tail failed: {e}")))?; + exec.validate_inactive_rows_are_empty() + .map_err(|e| PiCcsError::InvalidInput(format!("validate_inactive_rows_are_empty failed: {e}")))?; + let width_layout = Rv32WidthSidecarLayout::new(); + let include_width_lookup = self.shared_cpu_bus && program_requires_width_lookup(&program); + let (width_lookup_tables, width_lookup_addr_d) = if include_width_lookup { + let (tables, addr_d) = build_rv32_width_lookup_tables(&width_layout, &exec, trace.steps.len())?; + inject_rv32_width_lookup_events_into_trace(&mut trace, &exec, &width_layout)?; + (tables, addr_d) + } else { + (HashMap::new(), 0usize) + }; + + let requested_chunk_rows = self.chunk_rows.unwrap_or(DEFAULT_RV32_TRACE_CHUNK_ROWS); + if requested_chunk_rows == 0 { + return Err(PiCcsError::InvalidInput("trace chunk_rows must be non-zero".into())); + } + let step_rows = requested_chunk_rows.min(exec.rows.len().max(1)); + let exec_chunks = split_exec_into_fixed_chunks(&exec, step_rows)?; + + let mut layout = 
Rv32TraceCcsLayout::new(step_rows) + .map_err(|e| PiCcsError::InvalidInput(format!("Rv32TraceCcsLayout::new failed: {e}")))?; + + let prove_start = time_now(); + let setup_start = prove_start; + + let mut max_ram_addr = max_ram_addr_from_exec(&exec).unwrap_or(0); + if let Some(max_init_addr) = ram_init_map.keys().copied().max() { + max_ram_addr = max_ram_addr.max(max_init_addr); + } + let wants_ram_output = matches!(output_target, OutputTarget::Ram) && !output_claims.is_empty(); + if matches!(output_target, OutputTarget::Ram) { + if let Some(max_claim_addr) = output_claims.claimed_addresses().max() { + max_ram_addr = max_ram_addr.max(max_claim_addr); + } + } + let ram_d = required_bits_for_max_addr(max_ram_addr).max(2); + let ram_k = 1usize + .checked_shl(ram_d as u32) + .ok_or_else(|| PiCcsError::InvalidInput(format!("RAM address width too large: d={ram_d}")))?; + // Track A used-set derivation must be deterministic from public inputs/config. + // Do not derive RAM inclusion from runtime witness/events. 
+ let include_ram_sidecar = + program_requires_ram_sidecar(&program) || !ram_init_map.is_empty() || wants_ram_output; + + let mut mem_layouts: HashMap = HashMap::from([ + ( + REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + (PROG_ID.0, prog_layout.clone()), + ]); + if include_ram_sidecar { + mem_layouts.insert( + RAM_ID.0, + PlainMemLayout { + k: ram_k, + d: ram_d, + n_side: 2, + lanes: 1, + }, + ); + } + + let inferred_shout_ops = infer_required_trace_shout_opcodes(&program); + let shout_ops = match &self.shout_ops { + Some(override_ops) => { + let missing: HashSet = inferred_shout_ops + .difference(override_ops) + .copied() + .collect(); + if !missing.is_empty() { + let mut missing_names: Vec = missing.into_iter().map(|op| format!("{op:?}")).collect(); + missing_names.sort_unstable(); + return Err(PiCcsError::InvalidInput(format!( + "trace shout_ops override must be a superset of required opcodes; missing [{}]", + missing_names.join(", ") + ))); + } + override_ops.clone() + } + None => inferred_shout_ops, + }; + let decode_layout = Rv32DecodeSidecarLayout::new(); + let decode_lookup_tables = if self.shared_cpu_bus { + build_rv32_decode_lookup_tables(&prog_layout, &prog_init_words) + } else { + HashMap::new() + }; + let decode_lookup_bus_specs: Vec = if self.shared_cpu_bus { + let decode_lookup_cols = rv32_decode_lookup_backed_cols(&decode_layout); + decode_lookup_cols + .iter() + .copied() + .map(|col_id| TraceShoutBusSpec { + table_id: rv32_decode_lookup_table_id_for_col(col_id), + ell_addr: prog_layout.d, + n_vals: 1usize, + }) + .collect() + } else { + Vec::new() + }; + let width_lookup_bus_specs: Vec = if include_width_lookup { + let width_lookup_cols = rv32_width_lookup_backed_cols(&width_layout); + width_lookup_cols + .iter() + .copied() + .map(|col_id| TraceShoutBusSpec { + table_id: rv32_width_lookup_table_id_for_col(col_id), + ell_addr: width_lookup_addr_d, + n_vals: 1usize, + }) + .collect() + } else { + Vec::new() + }; 
+ let mut table_specs = rv32_trace_table_specs(&shout_ops); + let mut base_shout_table_ids: Vec = table_specs.keys().copied().collect(); + base_shout_table_ids.sort_unstable(); + for (&table_id, spec) in &self.extra_lut_table_specs { + if table_specs.contains_key(&table_id) { + return Err(PiCcsError::InvalidInput(format!( + "extra_lut_table_spec collides with existing table_id={table_id}" + ))); + } + table_specs.insert(table_id, spec.clone()); + } + let mut all_extra_shout_specs = self.extra_shout_bus_specs.clone(); + all_extra_shout_specs.extend(decode_lookup_bus_specs.clone()); + all_extra_shout_specs.extend(width_lookup_bus_specs.clone()); + for spec in &all_extra_shout_specs { + if !table_specs.contains_key(&spec.table_id) + && !decode_lookup_tables.contains_key(&spec.table_id) + && !width_lookup_tables.contains_key(&spec.table_id) + { + return Err(PiCcsError::InvalidInput(format!( + "extra_shout_bus_specs includes table_id={} without a table spec/table content", + spec.table_id + ))); + } + } + + let mut ccs_reserved_rows = 0usize; + if self.shared_cpu_bus { + let (bus_region_len, reserved_rows) = rv32_trace_shared_bus_requirements_with_specs( + &layout, + &base_shout_table_ids, + &all_extra_shout_specs, + &mem_layouts, + ) + .map_err(|e| { + PiCcsError::InvalidInput(format!("rv32_trace_shared_bus_requirements_with_specs failed: {e}")) + })?; + layout.m = layout + .m + .checked_add(bus_region_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace layout m overflow after bus tail reservation".into()))?; + ccs_reserved_rows = reserved_rows; + } + + let mut ccs = if ccs_reserved_rows == 0 { + build_rv32_trace_wiring_ccs(&layout) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_trace_wiring_ccs failed: {e}")))? + } else { + build_rv32_trace_wiring_ccs_with_reserved_rows(&layout, ccs_reserved_rows).map_err(|e| { + PiCcsError::InvalidInput(format!("build_rv32_trace_wiring_ccs_with_reserved_rows failed: {e}")) + })? 
+ }; + + let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs)?; + session.set_step_linking(StepLinkingConfig::new(vec![(layout.pc_final, layout.pc0)])); + + let mut prog_init_pairs: Vec<(u64, F)> = prog_init_words + .into_iter() + .filter_map(|((mem_id, addr), value)| (mem_id == PROG_ID.0 && value != F::ZERO).then_some((addr, value))) + .collect(); + prog_init_pairs.sort_by_key(|(addr, _)| *addr); + let prog_mem_init = if prog_init_pairs.is_empty() { + MemInit::Zero + } else { + MemInit::Sparse(prog_init_pairs) + }; + let mut initial_mem: HashMap<(u32, u64), F> = HashMap::new(); + if let MemInit::Sparse(pairs) = &prog_mem_init { + for &(addr, value) in pairs { + if value != F::ZERO { + initial_mem.insert((PROG_ID.0, addr), value); + } + } + } + for (®, &value) in ®_init_map { + let v = F::from_u64(value as u32 as u64); + if v != F::ZERO { + initial_mem.insert((REG_ID.0, reg), v); + } + } + for (&addr, &value) in &ram_init_map { + let v = F::from_u64(value as u32 as u64); + if v != F::ZERO { + initial_mem.insert((RAM_ID.0, addr), v); + } + } + + let setup_duration = elapsed_duration(setup_start); + let mut chunk_build_commit_duration = Duration::ZERO; + if self.shared_cpu_bus { + let chunk_start = time_now(); + + let mut lut_tables = decode_lookup_tables.clone(); + lut_tables.extend(width_lookup_tables.clone()); + let lut_lanes: HashMap = HashMap::new(); + + let mut cpu = R1csCpu::new( + ccs.clone(), + session.params().clone(), + session.committer().clone(), + layout.m_in, + &lut_tables, + &table_specs, + rv32_trace_chunk_to_witness(layout.clone()), + ) + .map_err(|e| PiCcsError::InvalidInput(format!("R1csCpu::new failed: {e}")))?; + cpu = cpu + .with_shared_cpu_bus( + rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &base_shout_table_ids, + &all_extra_shout_specs, + mem_layouts.clone(), + initial_mem.clone(), + ) + .map_err(|e| { + PiCcsError::InvalidInput(format!("rv32_trace_shared_cpu_bus_config_with_specs failed: {e}")) + })?, + 
layout.t, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + ccs = cpu.ccs.clone(); + + session.execute_shard_shared_cpu_bus_from_trace( + &trace, + max_steps, + layout.t, + &mem_layouts, + &lut_tables, + &table_specs, + &lut_lanes, + &initial_mem, + &cpu, + )?; + + if session.steps_witness().len() != exec_chunks.len() { + return Err(PiCcsError::ProtocolError(format!( + "shared trace build drift: step bundle count {} != exec chunk count {}", + session.steps_witness().len(), + exec_chunks.len() + ))); + } + chunk_build_commit_duration += elapsed_duration(chunk_start); + } else { + // Route-A legacy fallback: keep the main CPU witness as pure trace columns (no bus tail), + // and attach PROG/REG/RAM as separately committed no-shared-bus Twist instances linked at r_time. + let mut reg_state = init_reg_state(®_init_map)?; + let mut ram_state = init_ram_state(&ram_init_map, ram_d)?; + for exec_chunk in &exec_chunks { + let chunk_start = time_now(); + let reg_init_chunk = reg_state_to_sparse_map(®_state); + let ram_init_chunk = ram_state.clone(); + + let reg_mem_init = mem_init_from_u64_sparse(®_init_chunk, 32, "REG")?; + let ram_mem_init = mem_init_from_u64_sparse(&ram_init_chunk, ram_k, "RAM")?; + let twist_lanes = extract_twist_lanes_over_time(exec_chunk, ®_init_chunk, &ram_init_chunk, ram_d) + .map_err(|e| PiCcsError::InvalidInput(format!("extract_twist_lanes_over_time failed: {e}")))?; + + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, exec_chunk).map_err(|e| { + PiCcsError::InvalidInput(format!("rv32_trace_ccs_witness_from_exec_table failed: {e}")) + })?; + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(session.params(), &z_cpu); + let c_cpu = session.committer().commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + let prog_mem_inst = 
MemInstance { + mem_id: PROG_ID.0, + comms: Vec::new(), + k: prog_layout.k, + d: prog_layout.d, + n_side: prog_layout.n_side, + steps: layout.t, + lanes: 1, + ell: 1, + init: prog_mem_init.clone(), + }; + let reg_mem_inst = MemInstance { + mem_id: REG_ID.0, + comms: Vec::new(), + k: 32, + d: 5, + n_side: 2, + steps: layout.t, + lanes: 2, + ell: 1, + init: reg_mem_init, + }; + let ram_mem_inst = include_ram_sidecar.then_some(MemInstance { + mem_id: RAM_ID.0, + comms: Vec::new(), + k: ram_k, + d: ram_d, + n_side: 2, + steps: layout.t, + lanes: 1, + ell: 1, + init: ram_mem_init, + }); + + let prog_z = build_twist_only_bus_z( + ccs.m, + layout.m_in, + layout.t, + prog_mem_inst.d * prog_mem_inst.ell, + prog_mem_inst.lanes, + std::slice::from_ref(&twist_lanes.prog), + &x, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("build PROG twist z failed: {e}")))?; + let prog_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(session.params(), &prog_z); + let prog_c = session.committer().commit(&prog_Z); + let prog_mem_inst = MemInstance { + comms: vec![prog_c], + ..prog_mem_inst + }; + let prog_mem_wit = MemWitness { mats: vec![prog_Z] }; + + let reg_z = build_twist_only_bus_z( + ccs.m, + layout.m_in, + layout.t, + reg_mem_inst.d * reg_mem_inst.ell, + reg_mem_inst.lanes, + &[twist_lanes.reg_lane0.clone(), twist_lanes.reg_lane1.clone()], + &x, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("build REG twist z failed: {e}")))?; + let reg_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(session.params(), ®_z); + let reg_c = session.committer().commit(®_Z); + let reg_mem_inst = MemInstance { + comms: vec![reg_c], + ..reg_mem_inst + }; + let reg_mem_wit = MemWitness { mats: vec![reg_Z] }; + + let ram_mem = if let Some(ram_mem_inst) = ram_mem_inst { + let ram_z = build_twist_only_bus_z( + ccs.m, + layout.m_in, + layout.t, + ram_mem_inst.d * ram_mem_inst.ell, + ram_mem_inst.lanes, + std::slice::from_ref(&twist_lanes.ram), + &x, + ) + .map_err(|e| 
PiCcsError::InvalidInput(format!("build RAM twist z failed: {e}")))?; + let ram_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(session.params(), &ram_z); + let ram_c = session.committer().commit(&ram_Z); + let ram_mem_inst = MemInstance { + comms: vec![ram_c], + ..ram_mem_inst + }; + Some((ram_mem_inst, MemWitness { mats: vec![ram_Z] })) + } else { + None + }; + + let mut mem_instances = vec![(prog_mem_inst, prog_mem_wit), (reg_mem_inst, reg_mem_wit)]; + if let Some(ram_mem) = ram_mem { + mem_instances.push(ram_mem); + } + + session.add_step_bundle(StepWitnessBundle { + mcs, + lut_instances: Vec::<(LutInstance<_, _>, LutWitness)>::new(), + mem_instances, + _phantom: PhantomData::, + }); + + apply_exec_chunk_writes_to_state(exec_chunk, &mut reg_state, &mut ram_state)?; + chunk_build_commit_duration += elapsed_duration(chunk_start); + } + } + + let mem_order = session + .steps_public() + .first() + .map(|s| { + s.mem_insts + .iter() + .map(|inst| inst.mem_id) + .collect::>() + }) + .unwrap_or_default(); + let ram_ob_mem_idx = if wants_ram_output { + Some( + mem_order + .iter() + .position(|&id| id == RAM_ID.0) + .ok_or_else(|| PiCcsError::ProtocolError("missing RAM mem instance for output binding".into()))?, + ) + } else { + None + }; + let reg_ob_mem_idx = mem_order + .iter() + .position(|&id| id == REG_ID.0) + .ok_or_else(|| PiCcsError::ProtocolError("missing REG mem instance for output binding".into()))?; + + let fold_start = time_now(); + let (proof, output_binding_cfg) = if output_claims.is_empty() { + (session.fold_and_prove(&ccs)?, None) + } else { + let (ob_mem_idx, ob_num_bits, final_memory_state) = match output_target { + OutputTarget::Ram => ( + ram_ob_mem_idx.ok_or_else(|| { + PiCcsError::ProtocolError("missing RAM mem instance for output binding".into()) + })?, + ram_d, + final_ram_state_dense(&exec, &ram_init_map, ram_k)?, + ), + OutputTarget::Reg => (reg_ob_mem_idx, 5usize, final_reg_state_dense(&exec, ®_init_map)?), + }; + let ob_cfg = 
OutputBindingConfig::new(ob_num_bits, output_claims).with_mem_idx(ob_mem_idx); + let proof = session.fold_and_prove_with_output_binding_simple(&ccs, &ob_cfg, &final_memory_state)?; + (proof, Some(ob_cfg)) + }; + let fold_and_prove_duration = elapsed_duration(fold_start); + let prove_duration = elapsed_duration(prove_start); + let prove_phase_durations = Rv32TraceProvePhaseDurations { + setup: setup_duration, + chunk_build_commit: chunk_build_commit_duration, + fold_and_prove: fold_and_prove_duration, + }; + + let mut used_mem_ids: Vec = mem_layouts.keys().copied().collect(); + used_mem_ids.sort_unstable(); + let mut used_shout_table_ids = base_shout_table_ids.clone(); + for spec in &all_extra_shout_specs { + if !used_shout_table_ids.contains(&spec.table_id) { + used_shout_table_ids.push(spec.table_id); + } + } + used_shout_table_ids.sort_unstable(); + + Ok(Rv32TraceWiringRun { + session, + ccs, + layout, + exec, + proof, + used_mem_ids, + used_shout_table_ids, + output_binding_cfg, + prove_duration, + prove_phase_durations, + verify_duration: None, + }) + } +} + +/// Completed trace-wiring proof run. +pub struct Rv32TraceWiringRun { + session: FoldingSession, + ccs: CcsStructure, + layout: Rv32TraceCcsLayout, + exec: Rv32ExecTable, + proof: ShardProof, + used_mem_ids: Vec, + used_shout_table_ids: Vec, + output_binding_cfg: Option, + prove_duration: Duration, + prove_phase_durations: Rv32TraceProvePhaseDurations, + verify_duration: Option, +} + +impl Rv32TraceWiringRun { + pub fn params(&self) -> &NeoParams { + self.session.params() + } + + pub fn committer(&self) -> &AjtaiSModule { + self.session.committer() + } + + pub fn ccs(&self) -> &CcsStructure { + &self.ccs + } + + pub fn layout(&self) -> &Rv32TraceCcsLayout { + &self.layout + } + + pub fn exec_table(&self) -> &Rv32ExecTable { + &self.exec + } + + pub fn proof(&self) -> &ShardProof { + &self.proof + } + + /// Auto-derived memory sidecar IDs used by this run (`S_memory`). 
+ pub fn used_memory_ids(&self) -> &[u32] { + &self.used_mem_ids + } + + /// Auto-derived shout lookup table IDs used by this run (`S_lookup`). + pub fn used_shout_table_ids(&self) -> &[u32] { + &self.used_shout_table_ids + } + + pub fn verify_proof(&self, proof: &ShardProof) -> Result<(), PiCcsError> { + let ok = match &self.output_binding_cfg { + None => self.session.verify_collected(&self.ccs, proof)?, + Some(cfg) => self + .session + .verify_with_output_binding_collected_simple(&self.ccs, proof, cfg)?, + }; + if !ok { + return Err(PiCcsError::ProtocolError("verification failed".into())); + } + Ok(()) + } + + pub fn verify(&mut self) -> Result<(), PiCcsError> { + let verify_start = time_now(); + self.verify_proof(&self.proof)?; + self.verify_duration = Some(elapsed_duration(verify_start)); + Ok(()) + } + + pub fn ccs_num_constraints(&self) -> usize { + self.ccs.n + } + + pub fn ccs_num_variables(&self) -> usize { + self.ccs.m + } + + /// Number of real (active) rows in the unpadded trace. + pub fn trace_len(&self) -> usize { + self.exec.rows.iter().filter(|r| r.active).count() + } + + /// Number of collected folding steps. 
+ pub fn fold_count(&self) -> usize { + self.proof.steps.len() + } + + pub fn prove_duration(&self) -> Duration { + self.prove_duration + } + + pub fn prove_phase_durations(&self) -> Rv32TraceProvePhaseDurations { + self.prove_phase_durations + } + + pub fn verify_duration(&self) -> Option { + self.verify_duration + } + + pub fn steps_public(&self) -> Vec> { + self.session.steps_public() + } +} diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index 83c86ae8..9f7af36a 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -36,7 +36,10 @@ use neo_ccs::{CcsStructure, Mat, McsInstance, McsWitness, MeInstance}; use neo_math::ring::Rq as RqEl; use neo_math::{D, F, K}; use neo_memory::ajtai::encode_vector_balanced_to_mat; -use neo_memory::builder::{build_shard_witness_shared_cpu_bus_with_aux, CpuArithmetization, ShardWitnessAux}; +use neo_memory::builder::{ + build_shard_witness_shared_cpu_bus_from_trace_with_aux, build_shard_witness_shared_cpu_bus_with_aux, + CpuArithmetization, ShardWitnessAux, +}; use neo_memory::plain::{LutTable, PlainMemLayout}; use neo_memory::witness::LutTableSpec; use neo_memory::witness::{StepInstanceBundle, StepWitnessBundle}; @@ -52,6 +55,7 @@ use crate::shard::{self, CommitMixers, ShardProof as FoldRun, ShardProverContext use crate::PiCcsError; use neo_reductions::engines::optimized_engine::oracle::SparseCache; use neo_reductions::engines::utils; +use neo_vm_trace::VmTrace; #[inline] fn mode_uses_sparse_cache(mode: &FoldingMode) -> bool { @@ -344,7 +348,7 @@ fn indices_from_spec(spec: &StepSpec) -> Vec { /// consistency with the protocol engine (which expects padded y-vectors). pub fn me_from_z_balanced>( params: &NeoParams, - s: &CcsStructure, // should be identity-first + s: &CcsStructure, // rectangular or square CCS l: &Lm, z: &[F], r: &[K], @@ -459,7 +463,7 @@ pub fn me_from_z_balanced>( /// as specified by StepSpec indices. 
pub fn me_from_z_balanced_select>( params: &NeoParams, - s: &CcsStructure, // should be identity-first + s: &CcsStructure, // rectangular or square CCS l: &Lm, z: &[F], r: &[K], @@ -790,7 +794,7 @@ where /// Access the collected *public* per-step bundles (MCS + optional Twist/Shout instances). /// - /// This is useful for specialized verifiers (e.g. RV32 B1 statement checks) that need access + /// This is useful for specialized verifiers that need access /// to memory/lookup instances, not just the MCS list. pub fn steps_public(&self) -> Vec> { self.steps.iter().map(StepInstanceBundle::from).collect() @@ -801,6 +805,10 @@ where &self.steps } + pub fn steps_witness_mut(&mut self) -> &mut [StepWitnessBundle] { + &mut self.steps + } + /// Access auxiliary data captured during the most recent shared-CPU-bus witness build (if any). pub fn shared_bus_aux(&self) -> Option<&ShardWitnessAux> { self.shared_bus_aux.as_ref() @@ -1064,6 +1072,42 @@ where Ok(()) } + /// Add shared-CPU-bus step bundles from an already-executed trace. + /// + /// This avoids re-running `trace_program` when the caller already has a `VmTrace`. + pub fn execute_shard_shared_cpu_bus_from_trace( + &mut self, + trace: &VmTrace, + max_steps: usize, + chunk_size: usize, + mem_layouts: &HashMap, + lut_tables: &HashMap>, + lut_table_specs: &HashMap, + lut_lanes: &HashMap, + initial_mem: &HashMap<(u32, u64), F>, + cpu_arith: &A, + ) -> Result<(), PiCcsError> + where + A: CpuArithmetization, + { + let (bundles, aux) = build_shard_witness_shared_cpu_bus_from_trace_with_aux( + trace, + max_steps, + chunk_size, + mem_layouts, + lut_tables, + lut_table_specs, + lut_lanes, + initial_mem, + cpu_arith, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared-bus witness build failed: {e:?}")))?; + + self.add_step_bundles(bundles); + self.shared_bus_aux = Some(aux); + Ok(()) + } + /// Check if any steps have Twist (memory) instances. 
pub fn has_twist_instances(&self) -> bool { self.steps.iter().any(|s| !s.mem_instances.is_empty()) @@ -1121,6 +1165,22 @@ where return Ok(s); } + // Shared CPU bus is the only supported Route-A witness format. + let step0 = &self.steps[0]; + let is_shared_bus = step0 + .mem_instances + .iter() + .all(|(inst, wit)| inst.comms.is_empty() && wit.mats.is_empty()) + && step0 + .lut_instances + .iter() + .all(|(inst, wit)| inst.comms.is_empty() && wit.mats.is_empty()); + if !is_shared_bus { + return Err(PiCcsError::InvalidInput( + "legacy no-shared CPU bus witness format was removed; use shared-bus witness bundles".into(), + )); + } + let steps_public: Vec> = self.steps.iter().map(StepInstanceBundle::from).collect(); let (s_prepared, _cpu_bus) = @@ -1669,10 +1729,16 @@ where } }; - // For CCS-only sessions (no Twist/Shout), val-lane obligations should be empty - // For Twist+Shout sessions, val-lane obligations are expected and valid + // Val-lane obligations are expected when the session carries any sidecar val lane: + // Twist/Shout folds, or WB/WP folds over RV32 trace openings. 
let has_twist_or_shout = self.has_twist_instances() || self.has_shout_instances(); - if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + let has_wb_or_wp = run.steps.iter().any(|step| { + !step.mem.wb_me_claims.is_empty() + || !step.mem.wp_me_claims.is_empty() + || !step.wb_fold.is_empty() + || !step.wp_fold.is_empty() + }); + if !(has_twist_or_shout || has_wb_or_wp) && !outputs.obligations.val.is_empty() { return Err(PiCcsError::ProtocolError( "CCS-only session verification produced unexpected val-lane obligations".into(), )); @@ -1838,7 +1904,13 @@ where }; let has_twist_or_shout = self.has_twist_instances() || self.has_shout_instances(); - if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + let has_wb_or_wp = run.steps.iter().any(|step| { + !step.mem.wb_me_claims.is_empty() + || !step.mem.wp_me_claims.is_empty() + || !step.wb_fold.is_empty() + || !step.wp_fold.is_empty() + }); + if !(has_twist_or_shout || has_wb_or_wp) && !outputs.obligations.val.is_empty() { return Err(PiCcsError::ProtocolError( "CCS-only session verification produced unexpected val-lane obligations".into(), )); diff --git a/crates/neo-fold/src/session/ccs_builder.rs b/crates/neo-fold/src/session/ccs_builder.rs index 94e7e5cd..b374d81f 100644 --- a/crates/neo-fold/src/session/ccs_builder.rs +++ b/crates/neo-fold/src/session/ccs_builder.rs @@ -154,28 +154,12 @@ where ], ); - // Match `neo_ccs::r1cs_to_ccs` behavior: insert identity-first only when square. 
- let (matrices, f) = if n == m { - ( - vec![ - CcsMatrix::Identity { n }, - CcsMatrix::Csc(CscMat::from_triplets(a_trips, n, m)), - CcsMatrix::Csc(CscMat::from_triplets(b_trips, n, m)), - CcsMatrix::Csc(CscMat::from_triplets(c_trips, n, m)), - ], - f_base.insert_var_at_front(), - ) - } else { - ( - vec![ - CcsMatrix::Csc(CscMat::from_triplets(a_trips, n, m)), - CcsMatrix::Csc(CscMat::from_triplets(b_trips, n, m)), - CcsMatrix::Csc(CscMat::from_triplets(c_trips, n, m)), - ], - f_base, - ) - }; - - CcsStructure::new_sparse(matrices, f).map_err(|e| format!("CcsBuilder: invalid CCS: {e:?}")) + let matrices = vec![ + CcsMatrix::Csc(CscMat::from_triplets(a_trips, n, m)), + CcsMatrix::Csc(CscMat::from_triplets(b_trips, n, m)), + CcsMatrix::Csc(CscMat::from_triplets(c_trips, n, m)), + ]; + + CcsStructure::new_sparse(matrices, f_base).map_err(|e| format!("CcsBuilder: invalid CCS: {e:?}")) } } diff --git a/crates/neo-fold/src/session/circuit.rs b/crates/neo-fold/src/session/circuit.rs index b1f35d09..6b8e8cca 100644 --- a/crates/neo-fold/src/session/circuit.rs +++ b/crates/neo-fold/src/session/circuit.rs @@ -98,6 +98,7 @@ impl SharedBusR1csPreprocessing { &self.resources.lut_table_specs, chunk_to_witness, ) + .map_err(|e| PiCcsError::InvalidInput(format!("R1csCpu::new failed: {e}")))? .with_shared_cpu_bus( SharedCpuBusConfig { mem_layouts: self.resources.mem_layouts.clone(), @@ -105,6 +106,8 @@ impl SharedBusR1csPreprocessing { const_one_col: self.const_one_col, shout_cpu: self.shout_cpu.clone(), twist_cpu: self.twist_cpu.clone(), + shout_addr_groups: HashMap::new(), + shout_selector_groups: HashMap::new(), }, self.chunk_size, ) @@ -200,8 +203,8 @@ where // Ensure Shout lane counts are consistent across resources + cpu bindings. // - // If lanes aren't set explicitly in resources, infer them from the binding vector length - // so witness building + bus layout inference remain consistent. + // Empty/missing shout_cpu bindings mean "padding-only" and imply one bus lane. 
+ // If lanes aren't set explicitly in resources, infer them from binding len with this rule. { let mut table_ids: Vec = resources .lut_tables @@ -213,26 +216,20 @@ where table_ids.dedup(); for table_id in table_ids { - let bindings = shout_cpu.get(&table_id).ok_or_else(|| { - PiCcsError::InvalidInput(format!("missing shout_cpu binding for table_id={table_id}")) - })?; - if bindings.is_empty() { - return Err(PiCcsError::InvalidInput(format!( - "shout_cpu bindings for table_id={table_id} must be non-empty" - ))); - } + let bindings = shout_cpu.get(&table_id).map(Vec::as_slice).unwrap_or(&[]); + let inferred_lanes = bindings.len().max(1); match resources.lut_lanes.get(&table_id) { Some(&lanes) => { - if lanes.max(1) != bindings.len() { + if lanes.max(1) != inferred_lanes { return Err(PiCcsError::InvalidInput(format!( - "shout lanes mismatch for table_id={table_id}: resources.lut_lanes={} but cpu_bindings provides {}", + "shout lanes mismatch for table_id={table_id}: resources.lut_lanes={} but cpu_bindings implies {} lane(s) (empty/missing means 1 padding-only lane)", lanes, - bindings.len() + inferred_lanes ))); } } None => { - resources.lut_lanes.insert(table_id, bindings.len()); + resources.lut_lanes.insert(table_id, inferred_lanes); } } } @@ -292,6 +289,12 @@ fn shout_meta_for_bus( .ok_or_else(|| "2*xlen overflow for RISC-V shout table".to_string())?; Ok((d, 2usize)) } + LutTableSpec::RiscvOpcodePacked { .. } => { + Err("RiscvOpcodePacked is not supported in shared-bus circuits".into()) + } + LutTableSpec::RiscvOpcodeEventTablePacked { .. 
} => { + Err("RiscvOpcodeEventTablePacked is not supported in shared-bus circuits".into()) + } LutTableSpec::IdentityU32 => Ok((32usize, 2usize)), } } else { @@ -332,20 +335,16 @@ fn shared_bus_buslen_and_constraints( let ell_addr = d .checked_mul(ell) .ok_or_else(|| format!("ell_addr overflow for shout table_id={table_id}"))?; - let lanes = resources - .lut_lanes - .get(table_id) - .copied() - .unwrap_or(1) - .max(1); - let bindings = shout_cpu - .get(table_id) - .ok_or_else(|| format!("missing shout_cpu binding for table_id={table_id}"))?; - if bindings.len() != lanes { - return Err(format!( - "shout_cpu bindings for table_id={table_id} has len={}, expected lanes={lanes}", - bindings.len() - )); + let bindings = shout_cpu.get(table_id).map(Vec::as_slice).unwrap_or(&[]); + let lanes = bindings.len().max(1); + if let Some(&declared_lanes) = resources.lut_lanes.get(table_id) { + if declared_lanes.max(1) != lanes { + return Err(format!( + "shout lanes mismatch for table_id={table_id}: resources.lut_lanes={} but cpu_bindings implies {} lane(s) (empty/missing means 1 padding-only lane)", + declared_lanes, + lanes + )); + } } shout_ell_addrs_and_lanes.push((ell_addr, lanes)); } @@ -397,10 +396,14 @@ fn shared_bus_buslen_and_constraints( let mut builder = CpuConstraintBuilder::::new(/*n=*/ 1, /*m=*/ m_min, const_one_col); for (i, table_id) in table_ids.iter().enumerate() { - let cpus = shout_cpu - .get(table_id) - .ok_or_else(|| format!("missing shout_cpu binding for table_id={table_id}"))?; + let cpus = shout_cpu.get(table_id).map(Vec::as_slice).unwrap_or(&[]); let inst_cols = &bus_layout.shout_cols[i]; + if cpus.is_empty() { + for lane_cols in &inst_cols.lanes { + builder.add_shout_instance_padding(&bus_layout, lane_cols); + } + continue; + } if cpus.len() != inst_cols.lanes.len() { return Err(format!( "shared-bus shout lanes mismatch for table_id={table_id}: shout_cpu has len={} but bus layout expects {}", diff --git a/crates/neo-fold/src/shard.rs 
b/crates/neo-fold/src/shard.rs index 035bd1bb..cc810e94 100644 --- a/crates/neo-fold/src/shard.rs +++ b/crates/neo-fold/src/shard.rs @@ -32,8 +32,9 @@ use neo_ajtai::{ use neo_ccs::traits::SModuleHomomorphism; use neo_ccs::{CcsStructure, Mat, MeInstance}; use neo_math::{KExtensions, D, F, K}; +use neo_memory::riscv::trace::{Rv32DecodeSidecarLayout, Rv32TraceLayout}; use neo_memory::ts_common as ts; -use neo_memory::witness::{StepInstanceBundle, StepWitnessBundle}; +use neo_memory::witness::{LutTableSpec, StepInstanceBundle, StepWitnessBundle}; use neo_params::NeoParams; use neo_reductions::engines::optimized_engine::oracle::SparseCache; use neo_reductions::engines::utils; @@ -79,3888 +80,18 @@ fn elapsed_ms(start: TimePoint) -> f64 { } } -enum CcsOracleDispatch<'a> { - Optimized(neo_reductions::engines::optimized_engine::oracle::OptimizedOracle<'a, F>), - #[cfg(feature = "paper-exact")] - PaperExact(neo_reductions::engines::paper_exact_engine::oracle::PaperExactOracle<'a, F>), -} - -impl<'a> RoundOracle for CcsOracleDispatch<'a> { - fn evals_at(&mut self, points: &[K]) -> Vec { - match self { - Self::Optimized(oracle) => oracle.evals_at(points), - #[cfg(feature = "paper-exact")] - Self::PaperExact(oracle) => oracle.evals_at(points), - } - } - - fn num_rounds(&self) -> usize { - match self { - Self::Optimized(oracle) => oracle.num_rounds(), - #[cfg(feature = "paper-exact")] - Self::PaperExact(oracle) => oracle.num_rounds(), - } - } - - fn degree_bound(&self) -> usize { - match self { - Self::Optimized(oracle) => oracle.degree_bound(), - #[cfg(feature = "paper-exact")] - Self::PaperExact(oracle) => oracle.degree_bound(), - } - } - - fn fold(&mut self, r: K) { - match self { - Self::Optimized(oracle) => oracle.fold(r), - #[cfg(feature = "paper-exact")] - Self::PaperExact(oracle) => oracle.fold(r), - } - } -} - -// ============================================================================ -// Utilities -// 
============================================================================ - -pub use crate::memory_sidecar::memory::absorb_step_memory; - -// ============================================================================ -// Optional step-to-step (cross-chunk) linking -// ============================================================================ - -/// Optional verifier-side linking constraints across adjacent shard steps. -/// -/// This is intended for chunked CPU circuits that expose boundary state as part of the public -/// input vector `x` per step, and need the verifier to enforce that the state chains across steps. -#[derive(Clone, Debug)] -pub struct StepLinkingConfig { - /// Equalities on adjacent steps: require `steps[i].x[prev_idx] == steps[i+1].x[next_idx]`. - pub prev_next_equalities: Vec<(usize, usize)>, -} - -impl StepLinkingConfig { - pub fn new(prev_next_equalities: Vec<(usize, usize)>) -> Self { - Self { prev_next_equalities } - } -} - -pub fn check_step_linking(steps: &[StepInstanceBundle], cfg: &StepLinkingConfig) -> Result<(), PiCcsError> { - if steps.len() <= 1 || cfg.prev_next_equalities.is_empty() { - return Ok(()); - } - for (i, (prev, next)) in steps.iter().zip(steps.iter().skip(1)).enumerate() { - let prev_x = &prev.mcs_inst.x; - let next_x = &next.mcs_inst.x; - for &(prev_idx, next_idx) in &cfg.prev_next_equalities { - if prev_idx >= prev_x.len() || next_idx >= next_x.len() { - return Err(PiCcsError::InvalidInput(format!( - "step linking index out of range at boundary {i}: prev_x.len()={}, next_x.len()={}, pair=({prev_idx},{next_idx})", - prev_x.len(), - next_x.len(), - ))); - } - if prev_x[prev_idx] != next_x[next_idx] { - return Err(PiCcsError::ProtocolError(format!( - "step linking failed at boundary {i}: prev_x[{prev_idx}] != next_x[{next_idx}]", - ))); - } - } - } - Ok(()) -} - -/// Commitment mixers so the coordinator stays scheme-agnostic. -/// - `mix_rhos_commits(ρ, cs)` returns Σ ρ_i · c_i (S-action). 
-/// - `combine_b_pows(cs, b)` returns Σ \bar b^{i-1} c_i (DEC check). -#[derive(Clone, Copy)] -pub struct CommitMixers -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt, - MB: Fn(&[Cmt], u32) -> Cmt, -{ - pub mix_rhos_commits: MR, - pub combine_b_pows: MB, -} - -pub fn normalize_me_claims( - me_claims: &mut [MeInstance], - ell_n: usize, - ell_d: usize, - t: usize, -) -> Result<(), PiCcsError> { - let y_pad = 1usize << ell_d; - for (i, me) in me_claims.iter_mut().enumerate() { - if me.r.len() != ell_n { - return Err(PiCcsError::InvalidInput(format!( - "ME[{}] r.len()={}, expected ell_n={}", - i, - me.r.len(), - ell_n - ))); - } - if me.y.len() > t { - return Err(PiCcsError::InvalidInput(format!( - "ME[{}] y.len()={}, expected <= t={}", - i, - me.y.len(), - t - ))); - } - for (j, row) in me.y.iter_mut().enumerate() { - if row.len() > y_pad { - return Err(PiCcsError::InvalidInput(format!( - "ME[{}] y[{}].len()={}, expected <= {}", - i, - j, - row.len(), - y_pad - ))); - } - row.resize(y_pad, K::ZERO); - } - me.y.resize_with(t, || vec![K::ZERO; y_pad]); - if me.y_scalars.len() > t { - return Err(PiCcsError::InvalidInput(format!( - "ME[{}] y_scalars.len()={}, expected <= t={}", - i, - me.y_scalars.len(), - t - ))); - } - me.y_scalars.resize(t, K::ZERO); - } - Ok(()) -} - -fn validate_me_batch_invariants(batch: &[MeInstance], context: &str) -> Result<(), PiCcsError> { - if batch.is_empty() { - return Ok(()); - } - let me0 = &batch[0]; - let r0 = &me0.r; - let m_in0 = me0.m_in; - let y_len0 = me0.y.len(); - let y_row_len0 = me0.y.first().map(|r| r.len()).unwrap_or(0); - let y_scalars_len0 = me0.y_scalars.len(); - - if me0.X.rows() != D { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim 0 has X.rows()={}, expected D={}", - context, - me0.X.rows(), - D - ))); - } - if me0.X.cols() != m_in0 { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim 0 has X.cols()={}, expected m_in={}", - context, - me0.X.cols(), - m_in0 - ))); - } - - for (i, me) in 
batch.iter().enumerate().skip(1) { - if me.r != *r0 { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim {} has different r than claim 0 (r-alignment required for RLC)", - context, i - ))); - } - if me.m_in != m_in0 { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim {} has m_in={}, expected {}", - context, i, me.m_in, m_in0 - ))); - } - if me.X.rows() != D || me.X.cols() != m_in0 { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim {} has X shape {}x{}, expected {}x{}", - context, - i, - me.X.rows(), - me.X.cols(), - D, - m_in0 - ))); - } - if me.y.len() != y_len0 { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim {} has y.len()={}, expected {}", - context, - i, - me.y.len(), - y_len0 - ))); - } - for (j, row) in me.y.iter().enumerate() { - if row.len() != y_row_len0 { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim {} has y[{}].len()={}, expected {}", - context, - i, - j, - row.len(), - y_row_len0 - ))); - } - } - if me.y_scalars.len() != y_scalars_len0 { - return Err(PiCcsError::ProtocolError(format!( - "{}: ME claim {} has y_scalars.len()={}, expected {}", - context, - i, - me.y_scalars.len(), - y_scalars_len0 - ))); - } - } - Ok(()) -} - -#[derive(Clone, Copy, Debug)] -enum RlcLane { - Main, - Val, -} - -#[inline] -fn balanced_divrem_i64(v: i64, b: i64) -> (i64, i64) { - debug_assert!(b >= 2); - let mut r = v % b; - let mut q = (v - r) / b; - let half = b / 2; - if r > half { - r -= b; - q += 1; - } else if r < -half { - r += b; - q -= 1; - } - (r, q) -} - -#[inline] -fn balanced_divrem_i128(v: i128, b: i128) -> (i128, i128) { - debug_assert!(b >= 2); - let mut r = v % b; - let mut q = (v - r) / b; - let half = b / 2; - if r > half { - r -= b; - q += 1; - } else if r < -half { - r += b; - q -= 1; - } - (r, q) -} - -#[inline] -fn f_from_i64(x: i64) -> F { - if x >= 0 { - F::from_u64(x as u64) - } else { - F::ZERO - F::from_u64((-x) as u64) - } -} - -fn dec_stream_no_witness( - params: 
&NeoParams, - s: &CcsStructure, - parent: &MeInstance, - Z_mix: &Mat, - ell_d: usize, - k_dec: usize, - combine_b_pows: MB, - sparse: Option<&SparseCache>, -) -> Result<(Vec>, Vec, bool, bool, bool), PiCcsError> -where - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - if k_dec == 0 { - return Err(PiCcsError::InvalidInput("DEC: k_dec must be > 0".into())); - } - if Z_mix.rows() != D || Z_mix.cols() != s.m { - return Err(PiCcsError::InvalidInput(format!( - "DEC: Z_mix must have shape D×m = {}×{} (got {}×{})", - D, - s.m, - Z_mix.rows(), - Z_mix.cols() - ))); - } - - let d_pad = 1usize << ell_d; - let want_nc_channel = !(parent.s_col.is_empty() && parent.y_zcol.is_empty()); - if want_nc_channel && (parent.s_col.is_empty() || parent.y_zcol.is_empty()) { - return Err(PiCcsError::InvalidInput( - "DEC: incomplete NC channel on parent (expected both s_col and y_zcol)".into(), - )); - } - if want_nc_channel && parent.y_zcol.len() != d_pad { - return Err(PiCcsError::InvalidInput(format!( - "DEC: parent y_zcol length mismatch (expected {}, got {})", - d_pad, - parent.y_zcol.len() - ))); - } - - enum PpAccess { - Seeded { - kappa: usize, - chunk_size: usize, - chunk_seeds_by_row: Vec>, - }, - Loaded { - pp: Arc>, - }, - } - - let pp_access = if let Some(pp) = try_get_loaded_global_pp_for_dims(D, s.m) { - if pp.kappa == 0 { - return Err(PiCcsError::InvalidInput("DEC: PP.kappa must be > 0".into())); - } - PpAccess::Loaded { pp } - } else if let Ok((kappa, seed)) = get_global_pp_seeded_params_for_dims(D, s.m) { - if kappa == 0 { - return Err(PiCcsError::InvalidInput("DEC: PP.kappa must be > 0".into())); - } - let (chunk_size, chunk_seeds_by_row) = seeded_pp_chunk_seeds(seed, kappa, s.m); - PpAccess::Seeded { - kappa, - chunk_size, - chunk_seeds_by_row, - } - } else { - // Fallback: non-seeded entry. This will materialize PP if needed. 
- let pp = get_global_pp_for_dims(D, s.m).map_err(|e| { - PiCcsError::InvalidInput(format!("DEC: Ajtai PP unavailable for (d,m)=({},{}) ({})", D, s.m, e)) - })?; - if pp.kappa == 0 { - return Err(PiCcsError::InvalidInput("DEC: PP.kappa must be > 0".into())); - } - PpAccess::Loaded { pp } - }; - - // Build χ_r and v_j = M_j^T · χ_r (same as the reference DEC). - let ell_n = parent.r.len(); - let n_sz = 1usize - .checked_shl(ell_n as u32) - .ok_or_else(|| PiCcsError::InvalidInput("DEC: 2^ell_n overflow".into()))?; - let n_eff = core::cmp::min(s.n, n_sz); - - // χ_r table over the row/time hypercube. - // - // IMPORTANT: Use the same bit order as `eq_points_bool_mask` / `chi_tail_weights` - // (bit 0 = LSB) so CSC column traversals match the reference DEC. - #[inline] - fn chi_tail_weights(bits: &[K]) -> Vec { - let t = bits.len(); - let len = 1usize << t; - let mut w = vec![K::ZERO; len]; - w[0] = K::ONE; - for (i, &b) in bits.iter().enumerate() { - let step = 1usize << i; - let one_minus = K::ONE - b; - for mask in 0..step { - let v = w[mask]; - w[mask] = v * one_minus; - w[mask + step] = v * b; - } - } - w - } - - let chi_r = chi_tail_weights(&parent.r); - debug_assert_eq!(chi_r.len(), n_sz); - - let chi_s = if want_nc_channel { - let chi = chi_tail_weights(&parent.s_col); - if chi.len() < s.m { - return Err(PiCcsError::InvalidInput(format!( - "DEC: chi(s_col) too short for CCS width (need >= {}, got {})", - s.m, - chi.len() - ))); - } - chi - } else { - Vec::new() - }; - - let t_mats = s.t(); - - enum VjsAccess<'a> { - Dense(Vec>), - Sparse { - cap: usize, - cache: &'a SparseCache, - }, - } - - let vjs_access = if let Some(cache) = sparse { - if cache.len() != t_mats { - return Err(PiCcsError::InvalidInput(format!( - "DEC: sparse cache matrix count mismatch: got {}, expected {}", - cache.len(), - t_mats - ))); - } - let cap = core::cmp::min(s.m, n_eff); - VjsAccess::Sparse { cap, cache } - } else { - let mut vjs: Vec> = vec![vec![K::ZERO; s.m]; t_mats]; - for j in 
0..t_mats { - s.matrices[j].add_mul_transpose_into(&chi_r, &mut vjs[j], n_eff); - } - VjsAccess::Dense(vjs) - }; - - // Base-b powers in K for y_scalar recomposition. - let bF = F::from_u64(params.b as u64); - let bK = K::from(bF); - let mut pow_b_k = [K::ONE; D]; - for rho in 1..D { - pow_b_k[rho] = pow_b_k[rho - 1] * bK; - } - - // Precompute parameters for bounded signed decoding of Z_mix entries. - let b_u = params.b as u128; - let mut B_u: u128 = 1; - for _ in 0..k_dec { - B_u = B_u.saturating_mul(b_u); - } - let p: u128 = F::ORDER_U64 as u128; - - // Fast row-major access. - let z_rows: Vec<&[F]> = (0..D).map(|r| Z_mix.row(r)).collect(); - - struct Acc { - commit: Vec<[F; D]>, // [digit][kappa] -> [D] - y: Vec<[K; D]>, // [digit][t] -> [D] - y_zcol: Vec<[K; D]>, // [digit] -> [D] - any_nonzero: Vec, - vj: Vec, // scratch: t - digits: Vec, // scratch: k*D (balanced digits) - rot_next: [F; D], // scratch: rotation step output (written fully each time) - err: Option, // first error wins - } - - impl Acc { - fn new(k_dec: usize, kappa: usize, t: usize) -> Self { - Self { - commit: vec![[F::ZERO; D]; k_dec * kappa], - y: vec![[K::ZERO; D]; k_dec * t], - y_zcol: vec![[K::ZERO; D]; k_dec], - any_nonzero: vec![false; k_dec], - vj: vec![K::ZERO; t], - digits: vec![0i32; k_dec * D], - rot_next: [F::ZERO; D], - err: None, - } - } - - #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] - fn add_inplace(&mut self, rhs: &Acc, k_dec: usize, kappa: usize, t: usize) { - for (dst, src) in self.commit.iter_mut().zip(rhs.commit.iter()) { - for r in 0..D { - dst[r] += src[r]; - } - } - for (dst, src) in self.y.iter_mut().zip(rhs.y.iter()) { - for r in 0..D { - dst[r] += src[r]; - } - } - for (dst, src) in self.y_zcol.iter_mut().zip(rhs.y_zcol.iter()) { - for r in 0..D { - dst[r] += src[r]; - } - } - for i in 0..k_dec { - self.any_nonzero[i] |= rhs.any_nonzero[i]; - } - if self.err.is_none() { - self.err = rhs.err.clone(); - } - // silence unused warnings when 
parameters are const-propagated - let _ = (k_dec, kappa, t); - } - } - - let m = s.m; - let b_i64 = params.b as i64; - let b_i128 = params.b as i128; - - // Specialized rot_step for Φ₈₁(X) = X^54 + X^27 + 1 (η=81, D=54). - // Mirrors `neo_ajtai::commit::rot_step_phi_81` but kept local to avoid pulling a large - // D×D scratch table (`precompute_rot_columns`) into the hot DEC streaming loop. - #[inline] - fn rot_step_phi_81(cur: &[F; D], next: &mut [F; D]) { - let last = cur[D - 1]; - next[0] = F::ZERO; - next[1..D].copy_from_slice(&cur[..(D - 1)]); - next[0] -= last; - next[27] -= last; - } - - #[inline] - fn acc_add_assign(acc: &mut [F; D], col: &[F; D]) { - type P = ::Packing; - let prefix_len = D - (D % P::WIDTH); - let (acc_prefix, acc_suffix) = acc.split_at_mut(prefix_len); - let (col_prefix, col_suffix) = col.split_at(prefix_len); - - for (a, b) in P::pack_slice_mut(acc_prefix) - .iter_mut() - .zip(P::pack_slice(col_prefix).iter()) - { - *a += *b; - } - for (a, &b) in acc_suffix.iter_mut().zip(col_suffix.iter()) { - *a += b; - } - } - - #[inline] - fn acc_sub_assign(acc: &mut [F; D], col: &[F; D]) { - type P = ::Packing; - let prefix_len = D - (D % P::WIDTH); - let (acc_prefix, acc_suffix) = acc.split_at_mut(prefix_len); - let (col_prefix, col_suffix) = col.split_at(prefix_len); - - for (a, b) in P::pack_slice_mut(acc_prefix) - .iter_mut() - .zip(P::pack_slice(col_prefix).iter()) - { - *a -= *b; - } - for (a, &b) in acc_suffix.iter_mut().zip(col_suffix.iter()) { - *a -= b; - } - } - - #[inline] - fn acc_mul_add_assign(acc: &mut [F; D], col: &[F; D], scalar: F) { - type P = ::Packing; - let prefix_len = D - (D % P::WIDTH); - let (acc_prefix, acc_suffix) = acc.split_at_mut(prefix_len); - let (col_prefix, col_suffix) = col.split_at(prefix_len); - let scalar_p: P = scalar.into(); - - for (a, b) in P::pack_slice_mut(acc_prefix) - .iter_mut() - .zip(P::pack_slice(col_prefix).iter()) - { - *a += *b * scalar_p; - } - for (a, &b) in 
acc_suffix.iter_mut().zip(col_suffix.iter()) { - *a += b * scalar; - } - } - - let (kappa, acc) = match &pp_access { - PpAccess::Loaded { pp } => { - let kappa = pp.kappa; - let process_col = |mut st: Acc, col: usize| -> Acc { - if st.err.is_some() { - return st; - } +#[path = "shard/core_utils.rs"] +mod core_utils; +#[path = "shard/prover.rs"] +mod prover; +#[path = "shard/rlc_dec.rs"] +mod rlc_dec; +#[path = "shard/verifier_and_api.rs"] +mod verifier_and_api; - // Decompose the column's D entries into balanced base-b digits for each DEC child. - for rho in 0..D { - let u = z_rows[rho][col].as_canonical_u64() as u128; - if B_u <= i64::MAX as u128 { - let val_opt: Option = if u < B_u { - Some(u as i64) - } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { - Some(-((p - u) as i64)) - } else { - None - }; - let mut v = match val_opt { - Some(v) => v, - None => { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", - rho, col, k_dec, params.b - )); - return st; - } - }; - for i in 0..k_dec { - if v == 0 { - st.digits[i * D + rho] = 0; - continue; - } - let (r_i, q) = balanced_divrem_i64(v, b_i64); - if r_i != 0 { - st.any_nonzero[i] = true; - } - st.digits[i * D + rho] = r_i as i32; - v = q; - } - if v != 0 { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", - rho, col, k_dec, params.b - )); - return st; - } - } else { - let val_opt: Option = if u < B_u { - Some(u as i128) - } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { - Some(-((p - u) as i128)) - } else { - None - }; - let mut v = match val_opt { - Some(v) => v, - None => { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", - rho, col, k_dec, params.b - )); - return st; - } - }; - for i in 0..k_dec { - if v == 0 { - st.digits[i * D + rho] = 0; - continue; - } - let (r_i, q) = balanced_divrem_i128(v, b_i128); - if r_i != 0 { - st.any_nonzero[i] = true; - } - 
st.digits[i * D + rho] = r_i as i32; - v = q; - } - if v != 0 { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", - rho, col, k_dec, params.b - )); - return st; - } - } - } - - // vj[col] := M_j^T · χ_r (compute per column to avoid materializing all vjs). - match &vjs_access { - VjsAccess::Dense(vjs) => { - for j in 0..t_mats { - st.vj[j] = vjs[j][col]; - } - } - VjsAccess::Sparse { cap, cache } => { - for j in 0..t_mats { - st.vj[j] = if let Some(csc) = cache.csc(j) { - let mut sum = K::ZERO; - let s = csc.col_ptr[col]; - let e = csc.col_ptr[col + 1]; - for k in s..e { - let r = csc.row_idx[k]; - if r < n_eff { - sum += K::from(csc.vals[k]) * chi_r[r]; - } - } - sum - } else if col < *cap { - chi_r[col] - } else { - K::ZERO - }; - } - } - } - - // y_(i,j)[rho] += Z_i[rho,col] * vj[col] - for i in 0..k_dec { - let y_base = i * t_mats; - for rho in 0..D { - let digit = st.digits[i * D + rho]; - if digit == 0 { - continue; - } - for j in 0..t_mats { - let vj = st.vj[j]; - if vj != K::ZERO { - match digit { - 1 => st.y[y_base + j][rho] += vj, - -1 => st.y[y_base + j][rho] -= vj, - _ => st.y[y_base + j][rho] += vj.scale_base(f_from_i64(digit as i64)), - } - } - } - } - } - - // y_zcol_i[rho] += Z_i[rho,col] * χ_{s_col}[col] (optional). - if !chi_s.is_empty() { - let w_col = chi_s[col]; - if w_col != K::ZERO { - for i in 0..k_dec { - for rho in 0..D { - let digit = st.digits[i * D + rho]; - if digit == 0 { - continue; - } - match digit { - 1 => st.y_zcol[i][rho] += w_col, - -1 => st.y_zcol[i][rho] -= w_col, - _ => st.y_zcol[i][rho] += w_col.scale_base(f_from_i64(digit as i64)), - } - } - } - } - } - - // Commitment accumulators per digit. 
- for kr in 0..kappa { - let mut rot_col = neo_math::ring::cf(pp.m_rows[kr][col]); - for rho in 0..D { - for i in 0..k_dec { - let digit = st.digits[i * D + rho]; - if digit == 0 { - continue; - } - let acc = &mut st.commit[i * kappa + kr]; - match digit { - 1 => acc_add_assign(acc, &rot_col), - -1 => acc_sub_assign(acc, &rot_col), - _ => acc_mul_add_assign(acc, &rot_col, f_from_i64(digit as i64)), - } - } - rot_step_phi_81(&rot_col, &mut st.rot_next); - core::mem::swap(&mut rot_col, &mut st.rot_next); - } - } - - st - }; - - let acc = { - #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] - { - (0..m) - .into_par_iter() - .fold(|| Acc::new(k_dec, kappa, t_mats), |st, col| process_col(st, col)) - .reduce( - || Acc::new(k_dec, kappa, t_mats), - |mut a, b| { - if a.err.is_none() { - a.add_inplace(&b, k_dec, kappa, t_mats); - } - a - }, - ) - } - #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] - { - let mut st = Acc::new(k_dec, kappa, t_mats); - for col in 0..m { - st = process_col(st, col); - } - st - } - }; - (kappa, acc) - } - PpAccess::Seeded { - kappa, - chunk_size, - chunk_seeds_by_row, - } => { - let kappa = *kappa; - let chunk_size = *chunk_size; - let num_chunks = (m + chunk_size - 1) / chunk_size; - - let process_chunk = |mut st: Acc, chunk_idx: usize| -> Acc { - if st.err.is_some() { - return st; - } - - let start = chunk_idx * chunk_size; - let end = core::cmp::min(m, start + chunk_size); - - let mut rngs: Vec = (0..kappa) - .map(|kr| ChaCha8Rng::from_seed(chunk_seeds_by_row[kr][chunk_idx])) - .collect(); - - for col in start..end { - // Decompose the column's D entries into balanced base-b digits for each DEC child. 
- for rho in 0..D { - let u = z_rows[rho][col].as_canonical_u64() as u128; - if B_u <= i64::MAX as u128 { - let val_opt: Option = if u < B_u { - Some(u as i64) - } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { - Some(-((p - u) as i64)) - } else { - None - }; - let mut v = match val_opt { - Some(v) => v, - None => { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", - rho, col, k_dec, params.b - )); - return st; - } - }; - for i in 0..k_dec { - if v == 0 { - st.digits[i * D + rho] = 0; - continue; - } - let (r_i, q) = balanced_divrem_i64(v, b_i64); - if r_i != 0 { - st.any_nonzero[i] = true; - } - st.digits[i * D + rho] = r_i as i32; - v = q; - } - if v != 0 { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", - rho, col, k_dec, params.b - )); - return st; - } - } else { - let val_opt: Option = if u < B_u { - Some(u as i128) - } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { - Some(-((p - u) as i128)) - } else { - None - }; - let mut v = match val_opt { - Some(v) => v, - None => { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", - rho, col, k_dec, params.b - )); - return st; - } - }; - for i in 0..k_dec { - if v == 0 { - st.digits[i * D + rho] = 0; - continue; - } - let (r_i, q) = balanced_divrem_i128(v, b_i128); - if r_i != 0 { - st.any_nonzero[i] = true; - } - st.digits[i * D + rho] = r_i as i32; - v = q; - } - if v != 0 { - st.err = Some(format!( - "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", - rho, col, k_dec, params.b - )); - return st; - } - } - } - - // vj[col] := M_j^T · χ_r (compute per column to avoid materializing all vjs). 
- match &vjs_access { - VjsAccess::Dense(vjs) => { - for j in 0..t_mats { - st.vj[j] = vjs[j][col]; - } - } - VjsAccess::Sparse { cap, cache } => { - for j in 0..t_mats { - st.vj[j] = if let Some(csc) = cache.csc(j) { - let mut sum = K::ZERO; - let s = csc.col_ptr[col]; - let e = csc.col_ptr[col + 1]; - for k in s..e { - let r = csc.row_idx[k]; - if r < n_eff { - sum += K::from(csc.vals[k]) * chi_r[r]; - } - } - sum - } else if col < *cap { - chi_r[col] - } else { - K::ZERO - }; - } - } - } - - // y_(i,j)[rho] += Z_i[rho,col] * vj[col] - for i in 0..k_dec { - let y_base = i * t_mats; - for rho in 0..D { - let digit = st.digits[i * D + rho]; - if digit == 0 { - continue; - } - for j in 0..t_mats { - let vj = st.vj[j]; - if vj != K::ZERO { - match digit { - 1 => st.y[y_base + j][rho] += vj, - -1 => st.y[y_base + j][rho] -= vj, - _ => st.y[y_base + j][rho] += vj.scale_base(f_from_i64(digit as i64)), - } - } - } - } - } - - // y_zcol_i[rho] += Z_i[rho,col] * χ_{s_col}[col] (optional). - if !chi_s.is_empty() { - let w_col = chi_s[col]; - if w_col != K::ZERO { - for i in 0..k_dec { - for rho in 0..D { - let digit = st.digits[i * D + rho]; - if digit == 0 { - continue; - } - match digit { - 1 => st.y_zcol[i][rho] += w_col, - -1 => st.y_zcol[i][rho] -= w_col, - _ => st.y_zcol[i][rho] += w_col.scale_base(f_from_i64(digit as i64)), - } - } - } - } - } - - // Commitment accumulators per digit. 
- for kr in 0..kappa { - let a_kr_col = sample_uniform_rq(&mut rngs[kr]); - let mut rot_col = neo_math::ring::cf(a_kr_col); - for rho in 0..D { - for i in 0..k_dec { - let digit = st.digits[i * D + rho]; - if digit == 0 { - continue; - } - let acc = &mut st.commit[i * kappa + kr]; - match digit { - 1 => acc_add_assign(acc, &rot_col), - -1 => acc_sub_assign(acc, &rot_col), - _ => acc_mul_add_assign(acc, &rot_col, f_from_i64(digit as i64)), - } - } - rot_step_phi_81(&rot_col, &mut st.rot_next); - core::mem::swap(&mut rot_col, &mut st.rot_next); - } - } - } - - st - }; - - let acc = { - #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] - { - (0..num_chunks) - .into_par_iter() - .fold( - || Acc::new(k_dec, kappa, t_mats), - |st, chunk_idx| process_chunk(st, chunk_idx), - ) - .reduce( - || Acc::new(k_dec, kappa, t_mats), - |mut a, b| { - if a.err.is_none() { - a.add_inplace(&b, k_dec, kappa, t_mats); - } - a - }, - ) - } - #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] - { - let mut st = Acc::new(k_dec, kappa, t_mats); - for chunk_idx in 0..num_chunks { - st = process_chunk(st, chunk_idx); - } - st - } - }; - (kappa, acc) - } - }; - - if let Some(err) = acc.err { - return Err(PiCcsError::ProtocolError(err)); - } - - // Commitments c_i from accumulated columns. - let mut child_cs: Vec = Vec::with_capacity(k_dec); - for i in 0..k_dec { - if !acc.any_nonzero[i] { - child_cs.push(Cmt::zeros(D, kappa)); - continue; - } - let mut c = Cmt::zeros(D, kappa); - for kr in 0..kappa { - c.col_mut(kr).copy_from_slice(&acc.commit[i * kappa + kr]); - } - child_cs.push(c); - } - - // X_i: project first m_in columns from Z_i (small; compute sequentially). 
- let m_in = parent.m_in; - let mut xs_row_major: Vec> = vec![vec![F::ZERO; D * m_in]; k_dec]; - for col in 0..m_in { - for rho in 0..D { - let u = z_rows[rho][col].as_canonical_u64() as u128; - if B_u <= i64::MAX as u128 { - let val_opt: Option = if u < B_u { - Some(u as i64) - } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { - Some(-((p - u) as i64)) - } else { - None - }; - let mut v = val_opt.ok_or_else(|| { - PiCcsError::ProtocolError(format!( - "DEC split(X): Z_mix[{},{}] out of range for k_rho={}, b={}", - rho, col, k_dec, params.b - )) - })?; - for i in 0..k_dec { - if v == 0 { - break; - } - let (r_i, q) = balanced_divrem_i64(v, b_i64); - xs_row_major[i][rho * m_in + col] = f_from_i64(r_i); - v = q; - } - if v != 0 { - return Err(PiCcsError::ProtocolError(format!( - "DEC split(X): Z_mix[{},{}] needs more than k_rho={} digits in base b={}", - rho, col, k_dec, params.b - ))); - } - } else { - let val_opt: Option = if u < B_u { - Some(u as i128) - } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { - Some(-((p - u) as i128)) - } else { - None - }; - let mut v = val_opt.ok_or_else(|| { - PiCcsError::ProtocolError(format!( - "DEC split(X): Z_mix[{},{}] out of range for k_rho={}, b={}", - rho, col, k_dec, params.b - )) - })?; - for i in 0..k_dec { - if v == 0 { - break; - } - let (r_i, q) = balanced_divrem_i128(v, b_i128); - xs_row_major[i][rho * m_in + col] = f_from_i64(r_i as i64); - v = q; - } - if v != 0 { - return Err(PiCcsError::ProtocolError(format!( - "DEC split(X): Z_mix[{},{}] needs more than k_rho={} digits in base b={}", - rho, col, k_dec, params.b - ))); - } - } - } - } - - let parent_r = parent.r.clone(); - let fold_digest = parent.fold_digest; - - let mut children: Vec> = Vec::with_capacity(k_dec); - for i in 0..k_dec { - let Xi = Mat::from_row_major(D, m_in, xs_row_major[i].clone()); - let mut y_i: Vec> = Vec::with_capacity(t_mats); - let mut y_scalars_i: Vec = Vec::with_capacity(t_mats); - for j in 0..t_mats { - let mut 
yj = vec![K::ZERO; d_pad]; - let row = &acc.y[i * t_mats + j]; - for rho in 0..D { - yj[rho] = row[rho]; - } - let mut sc = K::ZERO; - for rho in 0..D { - sc += yj[rho] * pow_b_k[rho]; - } - y_i.push(yj); - y_scalars_i.push(sc); - } - - let y_zcol = if chi_s.is_empty() { - Vec::new() - } else { - let mut yz = vec![K::ZERO; d_pad]; - let row = &acc.y_zcol[i]; - for rho in 0..D { - yz[rho] = row[rho]; - } - yz - }; - - children.push(MeInstance:: { - c_step_coords: vec![], - u_offset: 0, - u_len: 0, - c: child_cs[i].clone(), - X: Xi, - r: parent_r.clone(), - s_col: parent.s_col.clone(), - y: y_i, - y_scalars: y_scalars_i, - y_zcol, - m_in, - fold_digest, - }); - } - - // Public checks (mirror paper-exact DEC). - let mut ok_y = true; - for j in 0..t_mats { - let mut lhs = vec![K::ZERO; d_pad]; - let mut pow = K::ONE; - for i in 0..k_dec { - for t in 0..d_pad { - lhs[t] += pow * children[i].y[j][t]; - } - pow *= bK; - } - if lhs != parent.y[j] { - ok_y = false; - break; - } - } - - // y_zcol: column-domain opening must also decompose (when present). 
- if ok_y && !chi_s.is_empty() { - let mut lhs = vec![K::ZERO; d_pad]; - let mut pow = K::ONE; - for i in 0..k_dec { - for t in 0..d_pad { - lhs[t] += pow * children[i].y_zcol[t]; - } - pow *= bK; - } - if lhs != parent.y_zcol { - ok_y = false; - } - } - - let mut lhs_X = Mat::zero(D, m_in, F::ZERO); - let mut pow = F::ONE; - for i in 0..k_dec { - for r in 0..D { - for c in 0..m_in { - lhs_X[(r, c)] += pow * children[i].X[(r, c)]; - } - } - pow *= bF; - } - let ok_X = lhs_X.as_slice() == parent.X.as_slice(); - - let ok_c = combine_b_pows(&child_cs, params.b) == parent.c; - Ok((children, child_cs, ok_y, ok_X, ok_c)) -} - -fn bind_rlc_inputs( - tr: &mut Poseidon2Transcript, - lane: RlcLane, - step_idx: usize, - me_inputs: &[MeInstance], -) -> Result<(), PiCcsError> { - let lane_scope: &'static [u8] = match lane { - RlcLane::Main => b"main", - RlcLane::Val => b"val", - }; - - // v2: binds NC-channel fields (s_col, y_zcol) so RLC challenges depend on the full instance. - tr.append_message(b"fold/rlc_inputs/v2", lane_scope); - tr.append_u64s(b"step_idx", &[step_idx as u64]); - tr.append_u64s(b"me_count", &[me_inputs.len() as u64]); - - for me in me_inputs { - tr.append_fields(b"c_data", &me.c.data); - tr.append_u64s(b"m_in", &[me.m_in as u64]); - tr.append_message(b"me_fold_digest", &me.fold_digest); - - for limb in &me.r { - tr.append_fields(b"r_limb", &limb.as_coeffs()); - } - - tr.append_u64s(b"s_col_len", &[me.s_col.len() as u64]); - for sc in &me.s_col { - tr.append_fields(b"s_col_elem", &sc.as_coeffs()); - } - - tr.append_u64s(b"y_zcol_len", &[me.y_zcol.len() as u64]); - for yz in &me.y_zcol { - tr.append_fields(b"y_zcol_elem", &yz.as_coeffs()); - } - - tr.append_fields(b"X", me.X.as_slice()); - - for yj in &me.y { - for &y_elem in yj { - tr.append_fields(b"y_elem", &y_elem.as_coeffs()); - } - } - - for ysc in &me.y_scalars { - tr.append_fields(b"y_scalar", &ysc.as_coeffs()); - } - - tr.append_u64s(b"c_step_coords_len", &[me.c_step_coords.len() as u64]); - 
tr.append_fields(b"c_step_coords", &me.c_step_coords); - tr.append_u64s(b"u_offset", &[me.u_offset as u64]); - tr.append_u64s(b"u_len", &[me.u_len as u64]); - } - - Ok(()) -} - -fn prove_rlc_dec_lane( - mode: &FoldingMode, - lane: RlcLane, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s: &CcsStructure, - ccs_sparse_cache: Option<&SparseCache>, - cpu_bus: Option<&neo_memory::cpu::BusLayout>, - ring: &ccs::RotRing, - ell_d: usize, - k_dec: usize, - step_idx: usize, - me_inputs: &[MeInstance], - wit_inputs: &[&Mat], - want_witnesses: bool, - l: &L, - mixers: CommitMixers, -) -> Result<(RlcDecProof, Vec>), PiCcsError> -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - if me_inputs.is_empty() { - let prefix = match lane { - RlcLane::Main => "", - RlcLane::Val => "val-lane ", - }; - return Err(PiCcsError::InvalidInput(format!( - "step {}: {prefix}RLC input batch is empty", - step_idx - ))); - } - if wit_inputs.len() != me_inputs.len() { - let prefix = match lane { - RlcLane::Main => "", - RlcLane::Val => "val-lane ", - }; - return Err(PiCcsError::InvalidInput(format!( - "step {}: {prefix}RLC witness count mismatch (me_inputs.len()={}, wit_inputs.len()={})", - step_idx, - me_inputs.len(), - wit_inputs.len() - ))); - } - - bind_rlc_inputs(tr, lane, step_idx, me_inputs)?; - let rlc_rhos = ccs::sample_rot_rhos_n(tr, params, ring, me_inputs.len())?; - let (mut rlc_parent, Z_mix) = if me_inputs.len() == 1 { - if rlc_rhos.len() != 1 { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_RLC(k=1): |rhos| must equal |inputs|", - step_idx - ))); - } - let inp = &me_inputs[0]; - - // Match `neo_reductions::api::rlc_with_commit` semantics for k=1 without cloning Z. - let inputs_c = vec![inp.c.clone()]; - let c = (mixers.mix_rhos_commits)(&rlc_rhos, &inputs_c); - - // Recompute y_scalars from digits (canonical). 
- let t = inp.y.len(); - if t < s.t() { - return Err(PiCcsError::InvalidInput(format!( - "step {}: Π_RLC(k=1): ME y.len() must be >= s.t() (got {}, s.t()={})", - step_idx, - t, - s.t() - ))); - } - for (j, row) in inp.y.iter().enumerate() { - if row.len() < D { - return Err(PiCcsError::InvalidInput(format!( - "step {}: Π_RLC(k=1): ME y[{}].len()={} must be >= D={}", - step_idx, - j, - row.len(), - D - ))); - } - } - let bK = K::from(F::from_u64(params.b as u64)); - let mut y_scalars = Vec::with_capacity(t); - for j in 0..t { - let mut sc = K::ZERO; - let mut pow = K::ONE; - for rho in 0..D { - sc += pow * inp.y[j][rho]; - pow *= bK; - } - y_scalars.push(sc); - } - - let out = MeInstance:: { - c_step_coords: vec![], - u_offset: 0, - u_len: 0, - c, - X: inp.X.clone(), - r: inp.r.clone(), - s_col: inp.s_col.clone(), - y: inp.y.clone(), - y_scalars, - y_zcol: inp.y_zcol.clone(), - m_in: inp.m_in, - fold_digest: inp.fold_digest, - }; - - (out, Cow::Borrowed(wit_inputs[0])) - } else { - // `ccs::rlc_with_commit` expects an owned slice; avoid changing the public API by cloning here. - let wit_owned: Vec> = wit_inputs.iter().map(|m| (*m).clone()).collect(); - let (out, Z_mix) = ccs::rlc_with_commit( - mode.clone(), - s, - params, - &rlc_rhos, - me_inputs, - &wit_owned, - ell_d, - mixers.mix_rhos_commits, - )?; - (out, Cow::Owned(Z_mix)) - }; - - let Z_mix = Z_mix.as_ref(); - - let can_stream_dec = - !want_witnesses && has_global_pp_for_dims(D, s.m) && !cpu_bus.map(|b| b.bus_cols > 0).unwrap_or(false); - - let (mut dec_children, ok_y, ok_X, ok_c, maybe_wits) = if can_stream_dec { - // Memory-optimized DEC: compute children + commitments without materializing Z_split. - // - // This is only used when we don't need to carry digit witnesses forward. 
- let (children, _child_cs, ok_y, ok_X, ok_c) = dec_stream_no_witness( - params, - s, - &rlc_parent, - Z_mix, - ell_d, - k_dec, - mixers.combine_b_pows, - ccs_sparse_cache, - )?; - (children, ok_y, ok_X, ok_c, Vec::new()) - } else { - // Standard DEC: materialize digit matrices (needed when carrying witnesses forward). - let (Z_split, digit_nonzero) = ccs::split_b_matrix_k_with_nonzero_flags(Z_mix, k_dec, params.b)?; - let zero_c = Cmt::zeros(rlc_parent.c.d, rlc_parent.c.kappa); - let child_cs: Vec = { - #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] - { - Z_split - .par_iter() - .enumerate() - .map(|(idx, Zi)| { - if digit_nonzero[idx] { - l.commit(Zi) - } else { - zero_c.clone() - } - }) - .collect() - } - #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] - { - Z_split - .iter() - .enumerate() - .map(|(idx, Zi)| { - if digit_nonzero[idx] { - l.commit(Zi) - } else { - zero_c.clone() - } - }) - .collect() - } - }; - let (dec_children, ok_y, ok_X, ok_c) = ccs::dec_children_with_commit_cached( - mode.clone(), - s, - params, - &rlc_parent, - &Z_split, - ell_d, - &child_cs, - mixers.combine_b_pows, - ccs_sparse_cache, - ); - (dec_children, ok_y, ok_X, ok_c, Z_split) - }; - if !(ok_y && ok_X && ok_c) { - let lane_label = match lane { - RlcLane::Main => "DEC", - RlcLane::Val => "DEC(val)", - }; - return Err(PiCcsError::ProtocolError(format!( - "{} public check failed at step {} (y={}, X={}, c={})", - lane_label, step_idx, ok_y, ok_X, ok_c - ))); - } - - // Shared CPU bus: carry the implicit bus openings through Π_RLC/Π_DEC so they remain - // part of the folded instance (and are checked by public DEC verification). 
- if let Some(bus) = cpu_bus { - if bus.bus_cols > 0 { - let core_t = s.t(); - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( - params, - bus, - core_t, - Z_mix, - &mut rlc_parent, - )?; - for (child, Zi) in dec_children.iter_mut().zip(maybe_wits.iter()) { - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance(params, bus, core_t, Zi, child)?; - } - } - } - - Ok(( - RlcDecProof { - rlc_rhos, - rlc_parent, - dec_children, - }, - maybe_wits, - )) -} - -fn verify_rlc_dec_lane( - lane: RlcLane, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s: &CcsStructure, - ring: &ccs::RotRing, - ell_d: usize, - mixers: CommitMixers, - step_idx: usize, - rlc_inputs: &[MeInstance], - rlc_rhos: &[Mat], - rlc_parent: &MeInstance, - dec_children: &[MeInstance], -) -> Result<(), PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - bind_rlc_inputs(tr, lane, step_idx, rlc_inputs)?; - - if rlc_rhos.len() != rlc_inputs.len() { - let prefix = match lane { - RlcLane::Main => "", - RlcLane::Val => "val-lane ", - }; - return Err(PiCcsError::InvalidInput(format!( - "step {}: {}RLC ρ count mismatch (expected {}, got {})", - step_idx, - prefix, - rlc_inputs.len(), - rlc_rhos.len() - ))); - } - - let rhos_from_tr = ccs::sample_rot_rhos_n(tr, params, ring, rlc_inputs.len())?; - for (j, (sampled, stored)) in rhos_from_tr.iter().zip(rlc_rhos.iter()).enumerate() { - if sampled.as_slice() != stored.as_slice() { - return Err(PiCcsError::ProtocolError(match lane { - RlcLane::Main => format!("step {}: RLC ρ #{} mismatch: transcript vs proof", step_idx, j), - RlcLane::Val => format!("step {}: val-lane RLC ρ #{} mismatch: transcript vs proof", step_idx, j), - })); - } - } - - let parent_pub = ccs::rlc_public(s, params, rlc_rhos, rlc_inputs, mixers.mix_rhos_commits, ell_d)?; - - let prefix = match lane { - RlcLane::Main => "", - RlcLane::Val => "val-lane ", - }; - if parent_pub.m_in != rlc_parent.m_in { 
- return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC m_in mismatch (public={}, proof={})", - step_idx, parent_pub.m_in, rlc_parent.m_in - ))); - } - if parent_pub.fold_digest != rlc_parent.fold_digest { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC fold_digest mismatch", - step_idx - ))); - } - if parent_pub.c_step_coords != rlc_parent.c_step_coords { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC c_step_coords mismatch", - step_idx - ))); - } - if parent_pub.u_offset != rlc_parent.u_offset { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC u_offset mismatch", - step_idx - ))); - } - if parent_pub.u_len != rlc_parent.u_len { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC u_len mismatch", - step_idx - ))); - } - if parent_pub.X != rlc_parent.X { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC X mismatch", - step_idx - ))); - } - if parent_pub.c != rlc_parent.c { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC commitment mismatch", - step_idx - ))); - } - if parent_pub.r != rlc_parent.r { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC r mismatch", - step_idx - ))); - } - if parent_pub.s_col != rlc_parent.s_col { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC s_col mismatch", - step_idx - ))); - } - if parent_pub.y != rlc_parent.y { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC y mismatch", - step_idx - ))); - } - if parent_pub.y_scalars != rlc_parent.y_scalars { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC y_scalars mismatch", - step_idx - ))); - } - if parent_pub.y_zcol != rlc_parent.y_zcol { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC y_zcol mismatch", - step_idx - ))); - } - - if rlc_parent.X.rows() != D || rlc_parent.X.cols() != rlc_parent.m_in { - return 
Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}RLC parent X shape {}x{} does not match m_in={}", - step_idx, - rlc_parent.X.rows(), - rlc_parent.X.cols(), - rlc_parent.m_in - ))); - } - if !dec_children.is_empty() { - validate_me_batch_invariants(dec_children, "verify step dec children")?; - for (child_idx, child) in dec_children.iter().enumerate() { - if child.m_in != rlc_parent.m_in { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}DEC child[{child_idx}] has m_in={}, expected {}", - step_idx, child.m_in, rlc_parent.m_in - ))); - } - if child.fold_digest != rlc_parent.fold_digest { - return Err(PiCcsError::ProtocolError(format!( - "step {}: {prefix}DEC child[{child_idx}] fold_digest mismatch", - step_idx - ))); - } - } - } - - if !ccs::verify_dec_public(s, params, rlc_parent, dec_children, mixers.combine_b_pows, ell_d) { - return Err(PiCcsError::ProtocolError(match lane { - RlcLane::Main => format!("step {}: DEC public check failed", step_idx), - RlcLane::Val => format!("step {}: val-lane DEC public check failed", step_idx), - })); - } - - Ok(()) -} +pub use core_utils::{absorb_step_memory, check_step_linking, CommitMixers, StepLinkingConfig}; +pub use verifier_and_api::*; -#[cfg(feature = "paper-exact")] -fn crosscheck_route_a_ccs_step( - cfg: &neo_reductions::engines::CrosscheckCfg, - step_idx: usize, - params: &NeoParams, - s: &CcsStructure, - cpu_bus: &neo_memory::cpu::BusLayout, - mcs_inst: &neo_ccs::McsInstance, - mcs_wit: &neo_ccs::McsWitness, - me_inputs: &[MeInstance], - me_witnesses: &[Mat], - ccs_out: &[MeInstance], - ccs_proof: &crate::PiCcsProof, - ell_d: usize, - ell_n: usize, - ell_m: usize, - d_sc: usize, - fold_digest: [u8; 32], - log: &L, -) -> Result<(), PiCcsError> -where - L: SModuleHomomorphism + Sync, -{ - let want_rounds_total = ell_n - .checked_add(ell_d) - .ok_or_else(|| PiCcsError::ProtocolError("ell_n + ell_d overflow".into()))?; - if ccs_proof.sumcheck_rounds.len() != want_rounds_total { - return 
Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck expects {} CCS sumcheck rounds, got {}", - step_idx, - want_rounds_total, - ccs_proof.sumcheck_rounds.len(), - ))); - } - if ccs_proof.sumcheck_challenges.len() != want_rounds_total { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck expects {} CCS sumcheck challenges, got {}", - step_idx, - want_rounds_total, - ccs_proof.sumcheck_challenges.len(), - ))); - } - let (s_col_prime, alpha_prime_nc) = if ccs_proof.variant == crate::optimized_engine::PiCcsProofVariant::SplitNcV1 { - let want_nc_rounds_total = ell_m - .checked_add(ell_d) - .ok_or_else(|| PiCcsError::ProtocolError("ell_m + ell_d overflow".into()))?; - if ccs_proof.sumcheck_rounds_nc.len() != want_nc_rounds_total { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck expects {} NC sumcheck rounds, got {}", - step_idx, - want_nc_rounds_total, - ccs_proof.sumcheck_rounds_nc.len(), - ))); - } - if ccs_proof.sumcheck_challenges_nc.len() != want_nc_rounds_total { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck expects {} NC sumcheck challenges, got {}", - step_idx, - want_nc_rounds_total, - ccs_proof.sumcheck_challenges_nc.len(), - ))); - } - ccs_proof.sumcheck_challenges_nc.split_at(ell_m) - } else { - (&[][..], &[][..]) - }; - - let (r_prime, alpha_prime) = ccs_proof.sumcheck_challenges.split_at(ell_n); - let r_inputs = me_inputs.first().map(|mi| mi.r.as_slice()); - - if cfg.initial_sum { - let lhs_exact = crate::paper_exact_engine::sum_q_over_hypercube_paper_exact( - s, - params, - core::slice::from_ref(mcs_wit), - me_witnesses, - &ccs_proof.challenges_public, - ell_d, - ell_n, - r_inputs, - ); - let initial_sum_prover = match ccs_proof.sc_initial_sum { - Some(x) => x, - None => ccs_proof - .sumcheck_rounds - .first() - .map(|p0| poly_eval_k(p0, K::ZERO) + poly_eval_k(p0, K::ONE)) - .ok_or_else(|| PiCcsError::ProtocolError("crosscheck: missing sumcheck round 0".into()))?, - }; - if 
lhs_exact != initial_sum_prover { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck initial sum mismatch (optimized vs paper-exact)", - step_idx - ))); - } - } - - if cfg.per_round { - let mut paper_oracle = crate::paper_exact_engine::oracle::PaperExactOracle::new( - s, - params, - core::slice::from_ref(mcs_wit), - me_witnesses, - ccs_proof.challenges_public.clone(), - ell_d, - ell_n, - d_sc, - r_inputs, - ); - - let mut any_mismatch = false; - for (round_idx, (opt_coeffs, &challenge)) in ccs_proof - .sumcheck_rounds - .iter() - .zip(ccs_proof.sumcheck_challenges.iter()) - .enumerate() - { - let deg = paper_oracle.degree_bound(); - let xs: Vec = (0..=deg).map(|t| K::from(F::from_u64(t as u64))).collect(); - let paper_evals = paper_oracle.evals_at(&xs); - - for (&x, &expected) in xs.iter().zip(paper_evals.iter()) { - let actual = poly_eval_k(opt_coeffs, x); - if actual != expected { - any_mismatch = true; - if cfg.fail_fast { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck round {} polynomial mismatch", - step_idx, round_idx - ))); - } - } - } - - paper_oracle.fold(challenge); - } - if any_mismatch { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck per-round polynomial mismatch", - step_idx - ))); - } - } - - if cfg.terminal { - let running_sum_prover = if let Some(initial) = ccs_proof.sc_initial_sum { - let mut running = initial; - for (coeffs, &ri) in ccs_proof - .sumcheck_rounds - .iter() - .zip(ccs_proof.sumcheck_challenges.iter()) - { - running = poly_eval_k(coeffs, ri); - } - running - } else { - ccs_proof - .sumcheck_rounds - .first() - .map(|p0| poly_eval_k(p0, K::ZERO) + poly_eval_k(p0, K::ONE)) - .unwrap_or(K::ZERO) - }; - - let rhs_fe = crate::paper_exact_engine::rhs_terminal_identity_fe_paper_exact( - s, - params, - &ccs_proof.challenges_public, - r_prime, - alpha_prime, - ccs_out, - r_inputs, - ); - let (lhs_fe, _rhs_unused) = 
crate::paper_exact_engine::q_eval_at_ext_point_fe_paper_exact_with_inputs( - s, - params, - core::slice::from_ref(mcs_wit), - me_witnesses, - alpha_prime, - r_prime, - &ccs_proof.challenges_public, - r_inputs, - ); - if rhs_fe != lhs_fe || rhs_fe != running_sum_prover { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck FE terminal evaluation claim mismatch", - step_idx - ))); - } - - let rhs_nc = crate::paper_exact_engine::rhs_terminal_identity_nc_paper_exact( - params, - &ccs_proof.challenges_public, - s_col_prime, - alpha_prime_nc, - ccs_out, - ); - if rhs_nc != ccs_proof.sumcheck_final_nc { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck NC terminal evaluation claim mismatch", - step_idx - ))); - } - } - - if cfg.outputs { - let mut out_me_ref = build_me_outputs_paper_exact( - s, - params, - core::slice::from_ref(mcs_inst), - core::slice::from_ref(mcs_wit), - me_inputs, - me_witnesses, - r_prime, - s_col_prime, - ell_d, - fold_digest, - log, - ); - - if cpu_bus.bus_cols > 0 { - let core_t = s.t(); - if out_me_ref.len() != 1 + me_witnesses.len() { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck CCS output count mismatch for bus openings (out_me_ref.len()={}, expected {})", - step_idx, - out_me_ref.len(), - 1 + me_witnesses.len() - ))); - } - - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( - params, - cpu_bus, - core_t, - &mcs_wit.Z, - &mut out_me_ref[0], - )?; - for (out, Z) in out_me_ref.iter_mut().skip(1).zip(me_witnesses.iter()) { - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance(params, cpu_bus, core_t, Z, out)?; - } - } - - if out_me_ref.len() != ccs_out.len() { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output length mismatch (paper={}, optimized={})", - step_idx, - out_me_ref.len(), - ccs_out.len() - ))); - } - - for (idx, (a, b)) in out_me_ref.iter().zip(ccs_out.iter()).enumerate() { - if a.m_in != b.m_in { - return 
Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] m_in mismatch (paper={}, optimized={})", - step_idx, a.m_in, b.m_in - ))); - } - if a.r != b.r { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] r mismatch", - step_idx - ))); - } - if a.s_col != b.s_col { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] s_col mismatch", - step_idx - ))); - } - if a.c.data != b.c.data { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] commitment mismatch", - step_idx - ))); - } - if a.y.len() != b.y.len() { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] y.len mismatch (paper={}, optimized={})", - step_idx, - a.y.len(), - b.y.len() - ))); - } - for (j, (ya, yb)) in a.y.iter().zip(b.y.iter()).enumerate() { - if ya != yb { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] y row {j} mismatch", - step_idx - ))); - } - } - if a.y_scalars != b.y_scalars { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] y_scalars mismatch", - step_idx - ))); - } - if a.y_zcol != b.y_zcol { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] y_zcol mismatch", - step_idx - ))); - } - if a.X.rows() != b.X.rows() || a.X.cols() != b.X.cols() { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] X dims mismatch (paper={}x{}, optimized={}x{})", - step_idx, - a.X.rows(), - a.X.cols(), - b.X.rows(), - b.X.cols() - ))); - } - for r in 0..a.X.rows() { - for c in 0..a.X.cols() { - if a.X[(r, c)] != b.X[(r, c)] { - return Err(PiCcsError::ProtocolError(format!( - "step {}: crosscheck output[{idx}] X mismatch at ({},{})", - step_idx, r, c - ))); - } - } - } - } - } - - Ok(()) -} - -// ============================================================================ -// Shard Proving -// 
============================================================================ - -#[derive(Clone)] -pub(crate) struct ShardProverContext { - pub ccs_mat_digest: Vec, - pub ccs_sparse_cache: Option>>, -} - -#[inline] -fn mode_uses_sparse_cache(mode: &FoldingMode) -> bool { - match mode { - FoldingMode::Optimized => true, - #[cfg(feature = "paper-exact")] - FoldingMode::OptimizedWithCrosscheck(_) => true, - #[cfg(feature = "paper-exact")] - FoldingMode::PaperExact => false, - } -} - -fn fold_shard_prove_impl( - collect_val_lane_wits: bool, - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepWitnessBundle], - step_idx_offset: usize, - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, - ob: Option<(&crate::output_binding::OutputBindingConfig, &[F])>, - prover_ctx: Option<&ShardProverContext>, - mut step_prove_ms_out: Option<&mut Vec>, -) -> Result<(ShardProof, Vec>, Vec>), PiCcsError> -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let (s, cpu_bus) = crate::memory_sidecar::cpu_bus::prepare_ccs_for_shared_cpu_bus_steps(s_me, steps)?; - let dims = utils::build_dims_and_policy(params, s)?; - let utils::Dims { - ell_d, - ell_n, - ell_m, - ell, - d_sc, - .. 
- } = dims; - let ccs_sparse_cache: Option>> = if mode_uses_sparse_cache(&mode) { - Some( - prover_ctx - .and_then(|ctx| ctx.ccs_sparse_cache.clone()) - .unwrap_or_else(|| Arc::new(SparseCache::build(s))), - ) - } else { - None - }; - let ccs_mat_digest = prover_ctx - .map(|ctx| ctx.ccs_mat_digest.clone()) - .unwrap_or_else(|| utils::digest_ccs_matrices_with_sparse_cache(s, ccs_sparse_cache.as_deref())); - if mode_uses_sparse_cache(&mode) && ccs_sparse_cache.is_none() { - return Err(PiCcsError::ProtocolError( - "missing SparseCache for optimized mode".into(), - )); - } - let k_dec = params.k_rho as usize; - let ring = ccs::RotRing::goldilocks(); - - if acc_init.len() != acc_wit_init.len() { - return Err(PiCcsError::InvalidInput(format!( - "acc_init.len()={} != acc_wit_init.len()={}", - acc_init.len(), - acc_wit_init.len() - ))); - } - - // Initialize accumulator - let mut accumulator = acc_init.to_vec(); - let mut accumulator_wit = acc_wit_init.to_vec(); - - let mut step_proofs = Vec::with_capacity(steps.len()); - let mut val_lane_wits: Vec> = Vec::new(); - let mut prev_twist_decoded: Option> = None; - let mut output_proof: Option = None; - - if ob.is_some() && steps.is_empty() { - return Err(PiCcsError::InvalidInput("output binding requires >= 1 step".into())); - } - - for (idx, step) in steps.iter().enumerate() { - let step_idx = step_idx_offset - .checked_add(idx) - .ok_or_else(|| PiCcsError::InvalidInput("step index overflow".into()))?; - let step_start = time_now(); - crate::memory_sidecar::memory::absorb_step_memory_witness(tr, step); - - let include_ob = ob.is_some() && (idx + 1 == steps.len()); - let mut ob_time_claim: Option = None; - let mut ob_r_prime: Option> = None; - - // Output binding is injected only on the final step, and must run before sampling Route-A `r_time`. 
- if include_ob { - let (cfg, final_memory_state) = - ob.ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but config missing".into()))?; - - if output_proof.is_some() { - return Err(PiCcsError::ProtocolError( - "output binding already attached (internal error)".into(), - )); - } - - if cfg.mem_idx >= step.mem_instances.len() { - return Err(PiCcsError::InvalidInput("output binding mem_idx out of range".into())); - } - let expected_k = 1usize - .checked_shl(cfg.num_bits as u32) - .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^num_bits overflow".into()))?; - if final_memory_state.len() != expected_k { - return Err(PiCcsError::InvalidInput(format!( - "output binding: final_memory_state.len()={} != 2^num_bits={}", - final_memory_state.len(), - expected_k - ))); - } - let mem_inst = &step.mem_instances[cfg.mem_idx].0; - if mem_inst.k != expected_k { - return Err(PiCcsError::InvalidInput(format!( - "output binding: cfg.num_bits implies k={}, but mem_inst.k={}", - expected_k, mem_inst.k - ))); - } - let ell_addr = mem_inst.twist_layout().lanes[0].ell_addr; - if ell_addr != cfg.num_bits { - return Err(PiCcsError::InvalidInput(format!( - "output binding: cfg.num_bits={}, but twist_layout.ell_addr={}", - cfg.num_bits, ell_addr - ))); - } - - tr.append_message(b"shard/output_binding_start", &(step_idx as u64).to_le_bytes()); - tr.append_u64s(b"output_binding/mem_idx", &[cfg.mem_idx as u64]); - tr.append_u64s(b"output_binding/num_bits", &[cfg.num_bits as u64]); - - let (output_sc, r_prime) = neo_memory::output_check::generate_output_sumcheck_proof_and_challenges( - tr, - cfg.num_bits, - cfg.program_io.clone(), - final_memory_state, - ) - .map_err(|e| PiCcsError::ProtocolError(format!("output sumcheck failed: {e:?}")))?; - - output_proof = Some(neo_memory::output_check::OutputBindingProof { output_sc }); - ob_r_prime = Some(r_prime); - } - - let (mcs_inst, mcs_wit) = &step.mcs; - - // k = accumulator.len() + 1 - let k = accumulator.len() + 1; - - // 
-------------------------------------------------------------------- - // Route A: Shared-challenge batched sum-check for time/row rounds. - // -------------------------------------------------------------------- - // - // 1) Bind CCS header + ME inputs - // 2) Sample CCS challenges (α, β, γ) and bind initial sum - // 3) Build CCS oracle + lazy Twist/Shout oracles - // 4) Run ONE batched sum-check for the first ell_n rounds (row/time) - // 5) Finish CCS alone for remaining ell_d Ajtai rounds - // 6) Emit CCS + memory ME claims at the shared r_time and fold via RLC/DEC - - utils::bind_header_and_instances_with_digest( - tr, - params, - &s, - core::slice::from_ref(mcs_inst), - dims, - &ccs_mat_digest, - )?; - utils::bind_me_inputs(tr, &accumulator)?; - let mut ch = utils::sample_challenges(tr, ell_d, ell)?; - ch.beta_m = utils::sample_beta_m(tr, ell_m)?; - let ccs_initial_sum = claimed_initial_sum_from_inputs(&s, &ch, &accumulator); - tr.append_fields(b"sumcheck/initial_sum", &ccs_initial_sum.as_coeffs()); - - // Route A memory checks use a separate transcript-derived cycle point `r_cycle` - // to form χ_{r_cycle}(t) weights inside their sum-check polynomials. - let r_cycle: Vec = - ts::sample_ext_point(tr, b"route_a/r_cycle", b"route_a/cycle/0", b"route_a/cycle/1", ell_n); - - // CCS oracle (engine-selected). - // - // Keep the optimized oracle concrete so we can build outputs from its Ajtai precompute. 
- let mut ccs_oracle: CcsOracleDispatch<'_> = match mode.clone() { - FoldingMode::Optimized => { - let sparse = ccs_sparse_cache - .as_ref() - .ok_or_else(|| PiCcsError::ProtocolError("missing SparseCache for optimized mode".into()))?; - CcsOracleDispatch::Optimized( - neo_reductions::engines::optimized_engine::oracle::OptimizedOracle::new_with_sparse( - &s, - params, - core::slice::from_ref(mcs_wit), - &accumulator_wit, - ch.clone(), - ell_d, - ell_n, - d_sc, - accumulator.first().map(|mi| mi.r.as_slice()), - sparse.clone(), - ), - ) - } - #[cfg(feature = "paper-exact")] - FoldingMode::PaperExact => CcsOracleDispatch::PaperExact( - neo_reductions::engines::paper_exact_engine::oracle::PaperExactOracle::new( - &s, - params, - core::slice::from_ref(mcs_wit), - &accumulator_wit, - ch.clone(), - ell_d, - ell_n, - d_sc, - accumulator.first().map(|mi| mi.r.as_slice()), - ), - ), - #[cfg(feature = "paper-exact")] - FoldingMode::OptimizedWithCrosscheck(_) => { - let sparse = ccs_sparse_cache - .as_ref() - .ok_or_else(|| PiCcsError::ProtocolError("missing SparseCache for optimized mode".into()))?; - CcsOracleDispatch::Optimized( - neo_reductions::engines::optimized_engine::oracle::OptimizedOracle::new_with_sparse( - &s, - params, - core::slice::from_ref(mcs_wit), - &accumulator_wit, - ch.clone(), - ell_d, - ell_n, - d_sc, - accumulator.first().map(|mi| mi.r.as_slice()), - sparse.clone(), - ), - ) - } - }; - - let shout_pre = crate::memory_sidecar::memory::prove_shout_addr_pre_time( - tr, - params, - step, - Some(&cpu_bus), - ell_n, - &r_cycle, - step_idx, - )?; - let twist_pre = crate::memory_sidecar::memory::prove_twist_addr_pre_time( - tr, - params, - step, - Some(&cpu_bus), - ell_n, - &r_cycle, - )?; - let twist_read_claims: Vec = twist_pre.iter().map(|p| p.read_check_claim_sum).collect(); - let twist_write_claims: Vec = twist_pre.iter().map(|p| p.write_check_claim_sum).collect(); - let mut mem_oracles = crate::memory_sidecar::memory::build_route_a_memory_oracles( - 
params, step, ell_n, &r_cycle, &shout_pre, &twist_pre, - )?; - - if include_ob { - let (cfg, _final_memory_state) = - ob.ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but config missing".into()))?; - let r_prime = ob_r_prime - .as_ref() - .ok_or_else(|| PiCcsError::ProtocolError("output binding r_prime missing".into()))?; - let pre = twist_pre - .get(cfg.mem_idx) - .ok_or_else(|| PiCcsError::ProtocolError("output binding mem_idx out of range for twist_pre".into()))?; - - if pre.decoded.lanes.is_empty() { - return Err(PiCcsError::ProtocolError( - "output binding: Twist decoded lanes empty".into(), - )); - } - - let mut oracles: Vec> = Vec::with_capacity(pre.decoded.lanes.len()); - let mut claimed_sum = K::ZERO; - for lane in pre.decoded.lanes.iter() { - let (oracle, claim) = neo_memory::twist_oracle::TwistTotalIncOracleSparseTime::new( - lane.wa_bits.clone(), - lane.has_write.clone(), - lane.inc_at_write_addr.clone(), - r_prime, - ); - oracles.push(Box::new(oracle)); - claimed_sum += claim; - } - let oracle = crate::memory_sidecar::memory::SumRoundOracle::new(oracles); - - ob_time_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { - oracle: Box::new(oracle), - claimed_sum, - label: crate::output_binding::OB_INC_TOTAL_LABEL, - }); - } - - let crate::memory_sidecar::route_a_time::RouteABatchedTimeProverOutput { - r_time, - per_claim_results, - proof: batched_time, - } = crate::memory_sidecar::route_a_time::prove_route_a_batched_time( - tr, - step_idx, - ell_n, - d_sc, - ccs_initial_sum, - &mut ccs_oracle, - &mut mem_oracles, - step, - twist_read_claims, - twist_write_claims, - ob_time_claim, - )?; - - // Finish CCS Ajtai rounds alone, continuing from the CCS oracle state after ell_n folds. 
- let ccs_time_rounds = per_claim_results - .first() - .map(|r| r.round_polys.clone()) - .unwrap_or_default(); - let mut sumcheck_rounds = ccs_time_rounds; - let mut sumcheck_chals = r_time.clone(); - let ajtai_initial_sum = per_claim_results - .first() - .map(|r| r.final_value) - .unwrap_or(ccs_initial_sum); - - let mut ccs_ajtai = RoundOraclePrefix::new(&mut ccs_oracle, ell_d); - let (ajtai_rounds, ajtai_chals) = - run_sumcheck_prover_ds(tr, b"ccs/ajtai", step_idx, &mut ccs_ajtai, ajtai_initial_sum)?; - let mut running_sum = ajtai_initial_sum; - for (round_poly, &r_i) in ajtai_rounds.iter().zip(ajtai_chals.iter()) { - running_sum = poly_eval_k(round_poly, r_i); - } - sumcheck_rounds.extend_from_slice(&ajtai_rounds); - sumcheck_chals.extend_from_slice(&ajtai_chals); - - // -------------------------------------------------------------------- - // NC-only sumcheck (digit-range / norm-check) over {0,1}^{ell_m + ell_d}. - // -------------------------------------------------------------------- - let mut ccs_nc_oracle = neo_reductions::engines::optimized_engine::oracle::NcOracle::new( - &s, - params, - core::slice::from_ref(mcs_wit), - &accumulator_wit, - ch.clone(), - ell_d, - ell_m, - d_sc, - ); - let (sumcheck_rounds_nc, sumcheck_chals_nc) = - run_sumcheck_prover_ds(tr, b"ccs/nc", step_idx, &mut ccs_nc_oracle, K::ZERO)?; - let mut running_sum_nc = K::ZERO; - for (round_poly, &r_i) in sumcheck_rounds_nc.iter().zip(sumcheck_chals_nc.iter()) { - running_sum_nc = poly_eval_k(round_poly, r_i); - } - let (s_col, _alpha_prime_nc) = sumcheck_chals_nc.split_at(ell_m); - - // Build CCS ME outputs at r_time. 
- let fold_digest = tr.digest32(); - let mut ccs_out = match &mut ccs_oracle { - CcsOracleDispatch::Optimized(oracle) => oracle.build_me_outputs_from_ajtai_precomp( - core::slice::from_ref(mcs_inst), - &accumulator, - s_col, - fold_digest, - l, - ), - #[cfg(feature = "paper-exact")] - CcsOracleDispatch::PaperExact(_) => build_me_outputs_paper_exact( - &s, - params, - core::slice::from_ref(mcs_inst), - core::slice::from_ref(mcs_wit), - &accumulator, - &accumulator_wit, - &r_time, - s_col, - ell_d, - fold_digest, - l, - ), - }; - - // CCS oracle borrows accumulator_wit; drop before updating accumulator_wit at the end. - drop(ccs_oracle); - - // Shared CPU bus: append "implicit openings" for all bus columns without materializing - // bus copyout matrices into the CCS. - if cpu_bus.bus_cols > 0 { - let core_t = s.t(); - if ccs_out.len() != 1 + accumulator_wit.len() { - return Err(PiCcsError::ProtocolError(format!( - "CCS output count mismatch for bus openings (ccs_out.len()={}, expected {})", - ccs_out.len(), - 1 + accumulator_wit.len() - ))); - } - - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( - params, - &cpu_bus, - core_t, - &mcs_wit.Z, - &mut ccs_out[0], - )?; - for (out, Z) in ccs_out.iter_mut().skip(1).zip(accumulator_wit.iter()) { - crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance(params, &cpu_bus, core_t, Z, out)?; - } - } - - if ccs_out.len() != k { - return Err(PiCcsError::ProtocolError(format!( - "Π_CCS returned {} outputs; expected k={k}", - ccs_out.len() - ))); - } - - let mut ccs_proof = crate::PiCcsProof::new(sumcheck_rounds, Some(ccs_initial_sum)); - ccs_proof.variant = crate::optimized_engine::PiCcsProofVariant::SplitNcV1; - ccs_proof.sumcheck_challenges = sumcheck_chals; - ccs_proof.sumcheck_rounds_nc = sumcheck_rounds_nc; - ccs_proof.sc_initial_sum_nc = Some(K::ZERO); - ccs_proof.sumcheck_challenges_nc = sumcheck_chals_nc; - ccs_proof.challenges_public = ch; - ccs_proof.sumcheck_final = running_sum; - 
ccs_proof.sumcheck_final_nc = running_sum_nc; - ccs_proof.header_digest = fold_digest.to_vec(); - - #[cfg(feature = "paper-exact")] - if let FoldingMode::OptimizedWithCrosscheck(cfg) = &mode { - crosscheck_route_a_ccs_step( - cfg, - step_idx, - params, - &s, - &cpu_bus, - mcs_inst, - mcs_wit, - &accumulator, - &accumulator_wit, - &ccs_out, - &ccs_proof, - ell_d, - ell_n, - ell_m, - d_sc, - fold_digest, - l, - )?; - } - - // Witnesses for CCS outputs: [Z_mcs, Z_seed...] (borrow; avoid multi-GB clones) - let mut outs_Z: Vec<&Mat> = Vec::with_capacity(k); - outs_Z.push(&mcs_wit.Z); - outs_Z.extend(accumulator_wit.iter()); - - // Memory sidecar: emit ME claims at the shared r_time (no fixed-challenge sumcheck). - let prev_step = (idx > 0).then(|| &steps[idx - 1]); - let prev_twist_decoded_ref = prev_twist_decoded.as_deref(); - let mut mem_proof = crate::memory_sidecar::memory::finalize_route_a_memory_prover( - tr, - params, - &cpu_bus, - &s, - step, - prev_step, - prev_twist_decoded_ref, - &mut mem_oracles, - &shout_pre.addr_pre, - &twist_pre, - &r_time, - mcs_inst.m_in, - step_idx, - )?; - prev_twist_decoded = Some(twist_pre.into_iter().map(|p| p.decoded).collect()); - - let y_len_total = s - .t() - .checked_add(cpu_bus.bus_cols) - .ok_or_else(|| PiCcsError::ProtocolError("t + bus_cols overflow".into()))?; - normalize_me_claims(&mut mem_proof.cpu_me_claims_val, ell_n, ell_d, y_len_total)?; - - validate_me_batch_invariants(&ccs_out, "prove step ccs outputs")?; - - let want_main_wits = collect_val_lane_wits || idx + 1 < steps.len(); - let (main_fold, Z_split) = prove_rlc_dec_lane( - &mode, - RlcLane::Main, - tr, - params, - &s, - ccs_sparse_cache.as_deref(), - Some(&cpu_bus), - &ring, - ell_d, - k_dec, - step_idx, - &ccs_out, - &outs_Z, - want_main_wits, - l, - mixers, - )?; - let RlcDecProof { - rlc_rhos: rhos, - rlc_parent: parent_pub, - dec_children: children, - } = main_fold; - - // -------------------------------------------------------------------- - // Phase 2: 
Second folding lane for Twist val-eval ME claims at r_val. - // -------------------------------------------------------------------- - let val_fold = if mem_proof.cpu_me_claims_val.is_empty() { - None - } else { - validate_me_batch_invariants(&mem_proof.cpu_me_claims_val, "prove step memory val outputs")?; - - tr.append_message(b"fold/val_lane_start", &(step_idx as u64).to_le_bytes()); - let mut val_wit_refs: Vec<&Mat> = Vec::with_capacity(mem_proof.cpu_me_claims_val.len()); - val_wit_refs.push(&mcs_wit.Z); - if let Some(prev) = prev_step { - val_wit_refs.push(&prev.mcs.1.Z); - } - if val_wit_refs.len() != mem_proof.cpu_me_claims_val.len() { - return Err(PiCcsError::ProtocolError(format!( - "Twist(val) witness count mismatch (have {}, need {})", - val_wit_refs.len(), - mem_proof.cpu_me_claims_val.len() - ))); - } - - // Avoid cloning/padding unless needed. - let need_pad = val_wit_refs.iter().any(|m| m.cols() != s.m); - let val_wits_owned: Option>> = if need_pad { - Some( - val_wit_refs - .iter() - .map(|m| ts::pad_mat_to_ccs_width(m, s.m)) - .collect::, _>>()?, - ) - } else { - None - }; - let val_wit_refs2: Vec<&Mat> = match &val_wits_owned { - Some(v) => v.iter().collect(), - None => val_wit_refs, - }; - let (val_fold, mut Z_split_val) = prove_rlc_dec_lane( - &mode, - RlcLane::Val, - tr, - params, - &s, - ccs_sparse_cache.as_deref(), - Some(&cpu_bus), - &ring, - ell_d, - k_dec, - step_idx, - &mem_proof.cpu_me_claims_val, - &val_wit_refs2, - collect_val_lane_wits, - l, - mixers, - )?; - - if collect_val_lane_wits { - val_lane_wits.extend(Z_split_val.drain(..)); - } - - Some(val_fold) - }; - - accumulator = children.clone(); - accumulator_wit = if want_main_wits { Z_split } else { Vec::new() }; - - step_proofs.push(StepProof { - fold: FoldStep { - ccs_out, - ccs_proof, - rlc_rhos: rhos, - rlc_parent: parent_pub, - dec_children: children, - }, - mem: mem_proof, - batched_time, - val_fold, - }); - - tr.append_message(b"fold/step_done", &(step_idx as 
u64).to_le_bytes()); - if let Some(out) = step_prove_ms_out.as_deref_mut() { - out.push(elapsed_ms(step_start)); - } - } - - Ok(( - ShardProof { - steps: step_proofs, - output_proof, - }, - accumulator_wit, - val_lane_wits, - )) -} - -pub fn fold_shard_prove( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepWitnessBundle], - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, -) -> Result -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( - false, - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - acc_wit_init, - l, - mixers, - None, - None, - None, - )?; - Ok(proof) -} - -pub(crate) fn fold_shard_prove_with_context( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepWitnessBundle], - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, - ctx: &ShardProverContext, -) -> Result -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( - false, - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - acc_wit_init, - l, - mixers, - None, - Some(ctx), - None, - )?; - Ok(proof) -} - -pub(crate) fn fold_shard_prove_with_context_and_step_timings( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepWitnessBundle], - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, - ctx: &ShardProverContext, -) -> Result<(ShardProof, Vec), PiCcsError> -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let mut 
step_prove_ms = Vec::with_capacity(steps.len()); - let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( - false, - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - acc_wit_init, - l, - mixers, - None, - Some(ctx), - Some(&mut step_prove_ms), - )?; - Ok((proof, step_prove_ms)) -} - -pub fn fold_shard_prove_with_output_binding( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepWitnessBundle], - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - final_memory_state: &[F], -) -> Result -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( - false, - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - acc_wit_init, - l, - mixers, - Some((ob_cfg, final_memory_state)), - None, - None, - )?; - Ok(proof) -} - -pub(crate) fn fold_shard_prove_with_output_binding_with_context( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepWitnessBundle], - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - final_memory_state: &[F], - ctx: &ShardProverContext, -) -> Result -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( - false, - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - acc_wit_init, - l, - mixers, - Some((ob_cfg, final_memory_state)), - Some(ctx), - None, - )?; - Ok(proof) -} - -pub fn fold_shard_prove_with_witnesses( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: 
&[StepWitnessBundle], - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, -) -> Result<(ShardProof, ShardFoldOutputs, ShardFoldWitnesses), PiCcsError> -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let (proof, final_main_wits, val_lane_wits) = fold_shard_prove_impl( - true, - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - acc_wit_init, - l, - mixers, - None, - None, - None, - )?; - let outputs = proof.compute_fold_outputs(acc_init); - if outputs.obligations.main.len() != final_main_wits.len() { - return Err(PiCcsError::ProtocolError(format!( - "final main witness count mismatch (have {}, need {})", - final_main_wits.len(), - outputs.obligations.main.len() - ))); - } - if outputs.obligations.val.len() != val_lane_wits.len() { - return Err(PiCcsError::ProtocolError(format!( - "val-lane witness count mismatch (have {}, need {})", - val_lane_wits.len(), - outputs.obligations.val.len() - ))); - } - Ok(( - proof, - outputs, - ShardFoldWitnesses { - final_main_wits, - val_lane_wits, - }, - )) -} - -/// Same as `fold_shard_prove_with_witnesses`, but offsets the per-step transcript index by `step_idx_offset`. -/// -/// This is useful for "continuation" style proving across multiple calls while preserving a globally -/// increasing step index for transcript domain separation. 
-pub fn fold_shard_prove_with_witnesses_with_step_offset( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepWitnessBundle], - acc_init: &[MeInstance], - acc_wit_init: &[Mat], - l: &L, - mixers: CommitMixers, - step_idx_offset: usize, -) -> Result<(ShardProof, ShardFoldOutputs, ShardFoldWitnesses), PiCcsError> -where - L: SModuleHomomorphism + Sync, - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - let (proof, final_main_wits, val_lane_wits) = fold_shard_prove_impl( - true, - mode, - tr, - params, - s_me, - steps, - step_idx_offset, - acc_init, - acc_wit_init, - l, - mixers, - None, - None, - None, - )?; - let outputs = proof.compute_fold_outputs(acc_init); - if outputs.obligations.main.len() != final_main_wits.len() { - return Err(PiCcsError::ProtocolError(format!( - "final main witness count mismatch (have {}, need {})", - final_main_wits.len(), - outputs.obligations.main.len() - ))); - } - if outputs.obligations.val.len() != val_lane_wits.len() { - return Err(PiCcsError::ProtocolError(format!( - "val-lane witness count mismatch (have {}, need {})", - val_lane_wits.len(), - outputs.obligations.val.len() - ))); - } - Ok(( - proof, - outputs, - ShardFoldWitnesses { - final_main_wits, - val_lane_wits, - }, - )) -} - -// ============================================================================ -// Shard Verification -// ============================================================================ - -fn fold_shard_verify_impl( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - step_idx_offset: usize, - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: Option<&crate::output_binding::OutputBindingConfig>, - prover_ctx: Option<&ShardProverContext>, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) 
-> Cmt + Clone + Copy, -{ - let (s, cpu_bus) = crate::memory_sidecar::cpu_bus::prepare_ccs_for_shared_cpu_bus_steps(s_me, steps)?; - let dims = utils::build_dims_and_policy(params, s)?; - let utils::Dims { - ell_d, - ell_n, - ell_m, - ell, - d_sc, - .. - } = dims; - let ring = ccs::RotRing::goldilocks(); - - if steps.len() != proof.steps.len() { - return Err(PiCcsError::InvalidInput(format!( - "step count mismatch: public {} vs proof {}", - steps.len(), - proof.steps.len() - ))); - } - if ob_cfg.is_some() && steps.is_empty() { - return Err(PiCcsError::InvalidInput("output binding requires >= 1 step".into())); - } - if ob_cfg.is_none() && proof.output_proof.is_some() { - return Err(PiCcsError::InvalidInput( - "shard proof contains output binding, but verifier did not supply OutputBindingConfig".into(), - )); - } - if ob_cfg.is_some() && proof.output_proof.is_none() { - return Err(PiCcsError::InvalidInput( - "verifier supplied OutputBindingConfig, but shard proof has no output binding".into(), - )); - } - - let mut accumulator = acc_init.to_vec(); - let mut val_lane_obligations: Vec> = Vec::new(); - let ccs_sparse_cache: Option>> = if mode_uses_sparse_cache(&mode) { - Some( - prover_ctx - .and_then(|ctx| ctx.ccs_sparse_cache.clone()) - .unwrap_or_else(|| Arc::new(SparseCache::build(s))), - ) - } else { - None - }; - let ccs_mat_digest = prover_ctx - .map(|ctx| ctx.ccs_mat_digest.clone()) - .unwrap_or_else(|| utils::digest_ccs_matrices_with_sparse_cache(s, ccs_sparse_cache.as_deref())); - - for (idx, (step, step_proof)) in steps.iter().zip(proof.steps.iter()).enumerate() { - let step_idx = step_idx_offset - .checked_add(idx) - .ok_or_else(|| PiCcsError::InvalidInput("step index overflow".into()))?; - absorb_step_memory(tr, step); - - let include_ob = ob_cfg.is_some() && (idx + 1 == steps.len()); - let mut ob_state: Option = None; - let mut ob_inc_total_degree_bound: Option = None; - - if include_ob { - let cfg = - ob_cfg.ok_or_else(|| PiCcsError::InvalidInput("output 
binding enabled but config missing".into()))?; - let ob_proof = proof - .output_proof - .as_ref() - .ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but proof missing".into()))?; - - if cfg.mem_idx >= step.mem_insts.len() { - return Err(PiCcsError::InvalidInput("output binding mem_idx out of range".into())); - } - let mem_inst = step - .mem_insts - .get(cfg.mem_idx) - .ok_or_else(|| PiCcsError::InvalidInput("output binding mem_idx out of range".into()))?; - let expected_k = 1usize - .checked_shl(cfg.num_bits as u32) - .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^num_bits overflow".into()))?; - if mem_inst.k != expected_k { - return Err(PiCcsError::InvalidInput(format!( - "output binding: cfg.num_bits implies k={}, but mem_inst.k={}", - expected_k, mem_inst.k - ))); - } - let ell_addr = mem_inst.twist_layout().lanes[0].ell_addr; - if ell_addr != cfg.num_bits { - return Err(PiCcsError::InvalidInput(format!( - "output binding: cfg.num_bits={}, but twist_layout.ell_addr={}", - cfg.num_bits, ell_addr - ))); - } - - tr.append_message(b"shard/output_binding_start", &(step_idx as u64).to_le_bytes()); - tr.append_u64s(b"output_binding/mem_idx", &[cfg.mem_idx as u64]); - tr.append_u64s(b"output_binding/num_bits", &[cfg.num_bits as u64]); - - let state = neo_memory::output_check::verify_output_sumcheck_rounds_get_state( - tr, - cfg.num_bits, - cfg.program_io.clone(), - &ob_proof.output_sc, - ) - .map_err(|e| PiCcsError::ProtocolError(format!("output sumcheck failed: {e:?}")))?; - ob_inc_total_degree_bound = Some(2 + ell_addr); - ob_state = Some(state); - } - - let mcs_inst = &step.mcs_inst; - - // -------------------------------------------------------------------- - // Route A: Verify shared-challenge batched sum-check (time/row rounds), - // then finish CCS Ajtai rounds, then proceed with RLC→DEC as before. - // -------------------------------------------------------------------- - - // Bind CCS header + ME inputs and sample public challenges. 
- utils::bind_header_and_instances_with_digest( - tr, - params, - &s, - core::slice::from_ref(mcs_inst), - dims, - &ccs_mat_digest, - )?; - utils::bind_me_inputs(tr, &accumulator)?; - let mut ch = utils::sample_challenges(tr, ell_d, ell)?; - if step_proof.fold.ccs_proof.variant == crate::optimized_engine::PiCcsProofVariant::SplitNcV1 { - ch.beta_m = utils::sample_beta_m(tr, ell_m)?; - } - let expected_ch = &step_proof.fold.ccs_proof.challenges_public; - if expected_ch.alpha != ch.alpha - || expected_ch.beta_a != ch.beta_a - || expected_ch.beta_r != ch.beta_r - || expected_ch.beta_m != ch.beta_m - || expected_ch.gamma != ch.gamma - { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS challenges_public mismatch", - idx - ))); - } - - // Public initial sum T for CCS sumcheck (engine-selected). - let claimed_initial = match &mode { - FoldingMode::Optimized => crate::optimized_engine::claimed_initial_sum_from_inputs(&s, &ch, &accumulator), - #[cfg(feature = "paper-exact")] - FoldingMode::PaperExact => { - crate::paper_exact_engine::claimed_initial_sum_from_inputs(&s, &ch, &accumulator) - } - #[cfg(feature = "paper-exact")] - FoldingMode::OptimizedWithCrosscheck(_) => { - crate::optimized_engine::claimed_initial_sum_from_inputs(&s, &ch, &accumulator) - } - }; - if let Some(x) = step_proof.fold.ccs_proof.sc_initial_sum { - if x != claimed_initial { - return Err(PiCcsError::SumcheckError( - "initial sum mismatch: proof claims different value than public T".into(), - )); - } - } - tr.append_fields(b"sumcheck/initial_sum", &claimed_initial.as_coeffs()); - - // Route A memory checks use a separate transcript-derived cycle point `r_cycle` - // to form χ_{r_cycle}(t) weights inside their sum-check polynomials. 
- let r_cycle: Vec = - ts::sample_ext_point(tr, b"route_a/r_cycle", b"route_a/cycle/0", b"route_a/cycle/1", ell_n); - - let shout_pre = crate::memory_sidecar::memory::verify_shout_addr_pre_time(tr, step, &step_proof.mem, step_idx)?; - let twist_pre = crate::memory_sidecar::memory::verify_twist_addr_pre_time(tr, step, &step_proof.mem)?; - let crate::memory_sidecar::route_a_time::RouteABatchedTimeVerifyOutput { r_time, final_values } = - crate::memory_sidecar::route_a_time::verify_route_a_batched_time( - tr, - step_idx, - ell_n, - d_sc, - claimed_initial, - step, - &step_proof.batched_time, - ob_inc_total_degree_bound, - )?; - - // CCS proof structure consistency with batched time proof. - let want_rounds_total = ell_n + ell_d; - if step_proof.fold.ccs_proof.sumcheck_rounds.len() != want_rounds_total { - return Err(PiCcsError::InvalidInput(format!( - "step {}: CCS sumcheck_rounds.len()={}, expected {}", - idx, - step_proof.fold.ccs_proof.sumcheck_rounds.len(), - want_rounds_total - ))); - } - if step_proof.fold.ccs_proof.sumcheck_challenges.len() != want_rounds_total { - return Err(PiCcsError::InvalidInput(format!( - "step {}: CCS sumcheck_challenges.len()={}, expected {}", - idx, - step_proof.fold.ccs_proof.sumcheck_challenges.len(), - want_rounds_total - ))); - } - for (round_idx, (a, b)) in step_proof - .fold - .ccs_proof - .sumcheck_rounds - .iter() - .take(ell_n) - .zip(step_proof.batched_time.round_polys[0].iter()) - .enumerate() - { - if a != b { - return Err(PiCcsError::ProtocolError(format!( - "step {}: CCS time round poly mismatch at round {}", - idx, round_idx - ))); - } - } - - if step_proof.fold.ccs_proof.sumcheck_challenges[..ell_n] != r_time { - return Err(PiCcsError::ProtocolError(format!( - "step {}: CCS time challenges mismatch with r_time", - idx - ))); - } - - let expected_k = accumulator.len() + 1; - if step_proof.fold.ccs_out.len() != expected_k { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS returned {} outputs; expected 
k={}", - idx, - step_proof.fold.ccs_out.len(), - expected_k - ))); - } - if step_proof.fold.ccs_out.is_empty() { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS produced empty ccs_out", - idx - ))); - } - if step_proof.fold.ccs_out[0].r != r_time { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output r != r_time (Route A requires shared r)", - idx - ))); - } - - // Bind Π_CCS outputs to the public MCS instance and carried ME inputs. - // - // - Commitments must match (Π_CCS does not change commitments). - // - `X` must match the digit-decomposition of public `x` for the MCS output. - // - `X` must match the carried ME inputs for subsequent outputs. - { - let out0 = &step_proof.fold.ccs_out[0]; - if out0.c != mcs_inst.c { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[0].c does not match mcs_inst.c", - idx - ))); - } - if out0.m_in != mcs_inst.m_in { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[0].m_in={}, expected {}", - idx, out0.m_in, mcs_inst.m_in - ))); - } - if out0.X.rows() != D || out0.X.cols() != mcs_inst.m_in { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[0].X has shape {}×{}, expected {}×{}", - idx, - out0.X.rows(), - out0.X.cols(), - D, - mcs_inst.m_in - ))); - } - - for (i, inp) in accumulator.iter().enumerate() { - let out = &step_proof.fold.ccs_out[i + 1]; - if out.c != inp.c { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[{}].c does not match accumulator[{}].c", - idx, - i + 1, - i - ))); - } - if out.m_in != inp.m_in { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[{}].m_in={}, expected {}", - idx, - i + 1, - out.m_in, - inp.m_in - ))); - } - if out.X != inp.X { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[{}].X does not match accumulator[{}].X", - idx, - i + 1, - i - ))); - } - } - } - - // Finish CCS Ajtai rounds alone (continuing transcript state 
after batched rounds). - let ajtai_rounds = &step_proof.fold.ccs_proof.sumcheck_rounds[ell_n..]; - let (ajtai_chals, running_sum, ok) = - verify_sumcheck_rounds_ds(tr, b"ccs/ajtai", step_idx, d_sc, final_values[0], ajtai_rounds); - if !ok { - return Err(PiCcsError::SumcheckError("Π_CCS Ajtai rounds invalid".into())); - } - - // Verify stored sumcheck challenges/final match transcript-derived values. - let mut r_all = r_time.clone(); - r_all.extend_from_slice(&ajtai_chals); - if r_all != step_proof.fold.ccs_proof.sumcheck_challenges { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS sumcheck challenges mismatch", - idx - ))); - } - if running_sum != step_proof.fold.ccs_proof.sumcheck_final { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS sumcheck_final mismatch", - idx - ))); - } - - // Validate ME input r length (required by RHS assembly if k>1). - for (i, me) in accumulator.iter().enumerate() { - if me.r.len() != ell_n { - return Err(PiCcsError::InvalidInput(format!( - "step {}: ME input r length mismatch at accumulator #{}: expected {}, got {}", - idx, - i, - ell_n, - me.r.len() - ))); - } - } - - if step_proof.fold.ccs_proof.variant != crate::optimized_engine::PiCcsProofVariant::SplitNcV1 { - return Err(PiCcsError::ProtocolError("unsupported Π_CCS proof variant".into())); - } - - // FE-only terminal identity. - let rhs_fe = crate::paper_exact_engine::rhs_terminal_identity_fe_paper_exact( - &s, - params, - &ch, - &r_time, - &ajtai_chals, - &step_proof.fold.ccs_out, - accumulator.first().map(|mi| mi.r.as_slice()), - ); - if running_sum != rhs_fe { - return Err(PiCcsError::SumcheckError( - "Π_CCS FE-only terminal identity check failed".into(), - )); - } - - // NC-only sumcheck + terminal identity. 
- if step_proof.fold.ccs_proof.sumcheck_rounds_nc.is_empty() { - return Err(PiCcsError::InvalidInput( - "Π_CCS SplitNcV1 requires non-empty sumcheck_rounds_nc".into(), - )); - } - if let Some(x) = step_proof.fold.ccs_proof.sc_initial_sum_nc { - if x != K::ZERO { - return Err(PiCcsError::InvalidInput( - "Π_CCS SplitNcV1 requires sc_initial_sum_nc == 0".into(), - )); - } - } - let want_nc_rounds_total = ell_m - .checked_add(ell_d) - .ok_or_else(|| PiCcsError::ProtocolError("ell_m + ell_d overflow".into()))?; - if step_proof.fold.ccs_proof.sumcheck_rounds_nc.len() != want_nc_rounds_total { - return Err(PiCcsError::InvalidInput(format!( - "step {}: Π_CCS NC sumcheck_rounds_nc.len()={}, expected {}", - idx, - step_proof.fold.ccs_proof.sumcheck_rounds_nc.len(), - want_nc_rounds_total - ))); - } - if step_proof.fold.ccs_proof.sumcheck_challenges_nc.len() != want_nc_rounds_total { - return Err(PiCcsError::InvalidInput(format!( - "step {}: Π_CCS NC sumcheck_challenges_nc.len()={}, expected {}", - idx, - step_proof.fold.ccs_proof.sumcheck_challenges_nc.len(), - want_nc_rounds_total - ))); - } - - let (nc_chals, running_sum_nc, ok_nc) = verify_sumcheck_rounds_ds( - tr, - b"ccs/nc", - step_idx, - d_sc, - K::ZERO, - &step_proof.fold.ccs_proof.sumcheck_rounds_nc, - ); - if !ok_nc { - return Err(PiCcsError::SumcheckError("Π_CCS NC rounds invalid".into())); - } - - if nc_chals != step_proof.fold.ccs_proof.sumcheck_challenges_nc { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS NC sumcheck challenges mismatch", - idx - ))); - } - if running_sum_nc != step_proof.fold.ccs_proof.sumcheck_final_nc { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS sumcheck_final_nc mismatch", - idx - ))); - } - - let (s_col_prime, alpha_prime_nc) = nc_chals.split_at(ell_m); - let d_pad = 1usize - .checked_shl(ell_d as u32) - .ok_or_else(|| PiCcsError::ProtocolError("2^ell_d overflow".into()))?; - for (out_idx, out) in step_proof.fold.ccs_out.iter().enumerate() { - if 
out.r != r_time { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[{out_idx}] r != r_time", - idx - ))); - } - if out.s_col.as_slice() != s_col_prime { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[{out_idx}] s_col mismatch", - idx - ))); - } - if out.y_zcol.len() != d_pad { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[{out_idx}] y_zcol.len()={}, expected {}", - idx, - out.y_zcol.len(), - d_pad - ))); - } - } - - let rhs_nc = crate::paper_exact_engine::rhs_terminal_identity_nc_paper_exact( - params, - &ch, - s_col_prime, - alpha_prime_nc, - &step_proof.fold.ccs_out, - ); - if running_sum_nc != rhs_nc { - return Err(PiCcsError::SumcheckError( - "Π_CCS NC terminal identity check failed".into(), - )); - } - - let observed_digest = tr.digest32(); - if observed_digest != step_proof.fold.ccs_proof.header_digest.as_slice() { - return Err(PiCcsError::ProtocolError("Π_CCS header digest mismatch".into())); - } - let expected_digest: [u8; 32] = step_proof - .fold - .ccs_proof - .header_digest - .as_slice() - .try_into() - .map_err(|_| PiCcsError::ProtocolError("Π_CCS header digest must be 32 bytes".into()))?; - for (out_idx, out) in step_proof.fold.ccs_out.iter().enumerate() { - if out.fold_digest != expected_digest { - return Err(PiCcsError::ProtocolError(format!( - "step {}: Π_CCS output[{out_idx}] fold_digest mismatch", - idx - ))); - } - } - - // Verify mem proofs (shared CPU bus only). 
- let prev_step = (idx > 0).then(|| &steps[idx - 1]); - let mem_out = crate::memory_sidecar::memory::verify_route_a_memory_step( - tr, - &cpu_bus, - step, - prev_step, - &step_proof.fold.ccs_out[0], - &r_time, - &r_cycle, - &final_values, - &step_proof.batched_time.claimed_sums, - 1, // claim 0 is CCS/time - &step_proof.mem, - &shout_pre, - &twist_pre, - step_idx, - )?; - - let expected_consumed = if include_ob { - final_values - .len() - .checked_sub(1) - .ok_or_else(|| PiCcsError::ProtocolError("missing output binding claim".into()))? - } else { - final_values.len() - }; - if mem_out.claim_idx_end != expected_consumed { - return Err(PiCcsError::ProtocolError(format!( - "step {}: batched claim index mismatch (consumed {}, expected {})", - idx, mem_out.claim_idx_end, expected_consumed - ))); - } - - if include_ob { - let cfg = - ob_cfg.ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but config missing".into()))?; - let ob_state = ob_state - .take() - .ok_or_else(|| PiCcsError::ProtocolError("output sumcheck state missing".into()))?; - - let inc_idx = final_values - .len() - .checked_sub(1) - .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total claim".into()))?; - if step_proof.batched_time.labels.get(inc_idx).copied() != Some(crate::output_binding::OB_INC_TOTAL_LABEL) { - return Err(PiCcsError::ProtocolError("output binding claim not last".into())); - } - - let inc_total_claim = *step_proof - .batched_time - .claimed_sums - .get(inc_idx) - .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total claimed_sum".into()))?; - let inc_total_final = *final_values - .get(inc_idx) - .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total final_value".into()))?; - - let twist_open = mem_out - .twist_time_openings - .get(cfg.mem_idx) - .ok_or_else(|| PiCcsError::ProtocolError("missing twist_time_openings for mem_idx".into()))?; - let inc_terminal = crate::output_binding::inc_terminal_from_time_openings(twist_open, &ob_state.r_prime) - .map_err(|e| 
PiCcsError::ProtocolError(format!("inc_total terminal mismatch: {e:?}")))?; - if inc_total_final != inc_terminal { - return Err(PiCcsError::ProtocolError("inc_total terminal mismatch".into())); - } - - let mem_inst = step - .mem_insts - .get(cfg.mem_idx) - .ok_or_else(|| PiCcsError::InvalidInput("output binding mem_idx out of range".into()))?; - let expected_k = 1usize - .checked_shl(cfg.num_bits as u32) - .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^num_bits overflow".into()))?; - if mem_inst.k != expected_k { - return Err(PiCcsError::InvalidInput(format!( - "output binding: cfg.num_bits implies k={}, but mem_inst.k={}", - expected_k, mem_inst.k - ))); - } - let ell_addr = mem_inst.twist_layout().lanes[0].ell_addr; - if ell_addr != cfg.num_bits { - return Err(PiCcsError::InvalidInput(format!( - "output binding: cfg.num_bits={}, but twist_layout.ell_addr={}", - cfg.num_bits, ell_addr - ))); - } - let val_init = crate::output_binding::val_init_from_mem_init(&mem_inst.init, mem_inst.k, &ob_state.r_prime) - .map_err(|e| PiCcsError::ProtocolError(format!("MemInit eval failed: {e:?}")))?; - - let val_final_at_r_prime = val_init + inc_total_claim; - let expected_out = ob_state.eq_eval * ob_state.io_mask_eval * (val_final_at_r_prime - ob_state.val_io_eval); - if expected_out != ob_state.output_final { - return Err(PiCcsError::ProtocolError("output binding final check failed".into())); - } - } - - validate_me_batch_invariants(&step_proof.fold.ccs_out, "verify step ccs outputs")?; - validate_me_batch_invariants(&step_proof.mem.cpu_me_claims_val, "verify step memory val outputs")?; - verify_rlc_dec_lane( - RlcLane::Main, - tr, - params, - &s, - &ring, - ell_d, - mixers, - step_idx, - &step_proof.fold.ccs_out, - &step_proof.fold.rlc_rhos, - &step_proof.fold.rlc_parent, - &step_proof.fold.dec_children, - )?; - - accumulator = step_proof.fold.dec_children.clone(); - - // Phase 2: Verify the r_val folding lane for Twist val-eval ME claims. 
- match ( - step_proof.mem.cpu_me_claims_val.is_empty(), - step_proof.val_fold.as_ref(), - ) { - (true, None) => {} - (true, Some(_)) => { - return Err(PiCcsError::ProtocolError(format!( - "step {}: unexpected val_fold proof (no r_val ME claims)", - idx - ))); - } - (false, None) => { - return Err(PiCcsError::ProtocolError(format!( - "step {}: missing val_fold proof (have r_val ME claims)", - idx - ))); - } - (false, Some(val_fold)) => { - tr.append_message(b"fold/val_lane_start", &(step_idx as u64).to_le_bytes()); - verify_rlc_dec_lane( - RlcLane::Val, - tr, - params, - &s, - &ring, - ell_d, - mixers, - step_idx, - &step_proof.mem.cpu_me_claims_val, - &val_fold.rlc_rhos, - &val_fold.rlc_parent, - &val_fold.dec_children, - )?; - - val_lane_obligations.extend_from_slice(&val_fold.dec_children); - } - } - - tr.append_message(b"fold/step_done", &(step_idx as u64).to_le_bytes()); - } - - Ok(ShardFoldOutputs { - obligations: ShardObligations { - main: accumulator, - val: val_lane_obligations, - }, - }) -} - -pub fn fold_shard_verify( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - fold_shard_verify_impl(mode, tr, params, s_me, steps, 0, acc_init, proof, mixers, None, None) -} - -/// Same as `fold_shard_verify`, but offsets the per-step transcript index by `step_idx_offset`. 
-pub fn fold_shard_verify_with_step_offset( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - step_idx_offset: usize, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - fold_shard_verify_impl( - mode, - tr, - params, - s_me, - steps, - step_idx_offset, - acc_init, - proof, - mixers, - None, - None, - ) -} - -pub fn fold_shard_verify_with_step_linking( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - step_linking: &StepLinkingConfig, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - check_step_linking(steps, step_linking)?; - fold_shard_verify(mode, tr, params, s_me, steps, acc_init, proof, mixers) -} - -pub fn fold_shard_verify_with_output_binding( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - fold_shard_verify_impl( - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - proof, - mixers, - Some(ob_cfg), - None, - ) -} - -pub(crate) fn fold_shard_verify_with_context( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - prover_ctx: &ShardProverContext, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> 
Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - fold_shard_verify_impl( - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - proof, - mixers, - None, - Some(prover_ctx), - ) -} - -pub(crate) fn fold_shard_verify_with_step_linking_with_context( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - step_linking: &StepLinkingConfig, - prover_ctx: &ShardProverContext, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - check_step_linking(steps, step_linking)?; - fold_shard_verify_with_context(mode, tr, params, s_me, steps, acc_init, proof, mixers, prover_ctx) -} - -pub(crate) fn fold_shard_verify_with_output_binding_with_context( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - prover_ctx: &ShardProverContext, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - fold_shard_verify_impl( - mode, - tr, - params, - s_me, - steps, - 0, - acc_init, - proof, - mixers, - Some(ob_cfg), - Some(prover_ctx), - ) -} - -pub(crate) fn fold_shard_verify_with_output_binding_and_step_linking_with_context( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - step_linking: &StepLinkingConfig, - prover_ctx: &ShardProverContext, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + 
Copy, -{ - check_step_linking(steps, step_linking)?; - fold_shard_verify_with_output_binding_with_context( - mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg, prover_ctx, - ) -} - -pub fn fold_shard_verify_with_output_binding_and_step_linking( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - step_linking: &StepLinkingConfig, -) -> Result, PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, -{ - check_step_linking(steps, step_linking)?; - fold_shard_verify_with_output_binding(mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg) -} - -pub fn fold_shard_verify_and_finalize( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - finalizer: &mut Fin, -) -> Result<(), PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, - Fin: ObligationFinalizer, -{ - let outputs = fold_shard_verify(mode, tr, params, s_me, steps, acc_init, proof, mixers)?; - let report = finalizer.finalize(&outputs.obligations)?; - outputs - .obligations - .require_all_finalized(report.did_finalize_main, report.did_finalize_val)?; - Ok(()) -} - -pub fn fold_shard_verify_and_finalize_with_step_linking( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - step_linking: &StepLinkingConfig, - finalizer: &mut Fin, -) -> Result<(), PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, - Fin: 
ObligationFinalizer, -{ - check_step_linking(steps, step_linking)?; - fold_shard_verify_and_finalize(mode, tr, params, s_me, steps, acc_init, proof, mixers, finalizer) -} - -pub fn fold_shard_verify_and_finalize_with_output_binding( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - finalizer: &mut Fin, -) -> Result<(), PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, - Fin: ObligationFinalizer, -{ - let outputs = - fold_shard_verify_with_output_binding(mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg)?; - let report = finalizer.finalize(&outputs.obligations)?; - outputs - .obligations - .require_all_finalized(report.did_finalize_main, report.did_finalize_val)?; - Ok(()) -} - -pub fn fold_shard_verify_and_finalize_with_output_binding_and_step_linking( - mode: FoldingMode, - tr: &mut Poseidon2Transcript, - params: &NeoParams, - s_me: &CcsStructure, - steps: &[StepInstanceBundle], - acc_init: &[MeInstance], - proof: &ShardProof, - mixers: CommitMixers, - ob_cfg: &crate::output_binding::OutputBindingConfig, - step_linking: &StepLinkingConfig, - finalizer: &mut Fin, -) -> Result<(), PiCcsError> -where - MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, - MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, - Fin: ObligationFinalizer, -{ - check_step_linking(steps, step_linking)?; - fold_shard_verify_and_finalize_with_output_binding( - mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg, finalizer, - ) -} +pub(crate) use core_utils::*; +pub(crate) use prover::*; +pub(crate) use rlc_dec::*; diff --git a/crates/neo-fold/src/shard/core_utils.rs b/crates/neo-fold/src/shard/core_utils.rs new file mode 100644 index 00000000..e4ac07de --- /dev/null +++ b/crates/neo-fold/src/shard/core_utils.rs @@ 
-0,0 +1,1339 @@ +use super::*; + +pub(crate) enum CcsOracleDispatch<'a> { + Optimized(neo_reductions::engines::optimized_engine::oracle::OptimizedOracle<'a, F>), + #[cfg(feature = "paper-exact")] + PaperExact(neo_reductions::engines::paper_exact_engine::oracle::PaperExactOracle<'a, F>), +} + +impl<'a> RoundOracle for CcsOracleDispatch<'a> { + fn evals_at(&mut self, points: &[K]) -> Vec { + match self { + Self::Optimized(oracle) => oracle.evals_at(points), + #[cfg(feature = "paper-exact")] + Self::PaperExact(oracle) => oracle.evals_at(points), + } + } + + fn num_rounds(&self) -> usize { + match self { + Self::Optimized(oracle) => oracle.num_rounds(), + #[cfg(feature = "paper-exact")] + Self::PaperExact(oracle) => oracle.num_rounds(), + } + } + + fn degree_bound(&self) -> usize { + match self { + Self::Optimized(oracle) => oracle.degree_bound(), + #[cfg(feature = "paper-exact")] + Self::PaperExact(oracle) => oracle.degree_bound(), + } + } + + fn fold(&mut self, r: K) { + match self { + Self::Optimized(oracle) => oracle.fold(r), + #[cfg(feature = "paper-exact")] + Self::PaperExact(oracle) => oracle.fold(r), + } + } +} + +// ============================================================================ +// Utilities +// ============================================================================ + +pub use crate::memory_sidecar::memory::absorb_step_memory; + +// ============================================================================ +// Optional step-to-step (cross-chunk) linking +// ============================================================================ + +/// Optional verifier-side linking constraints across adjacent shard steps. +/// +/// This is intended for chunked CPU circuits that expose boundary state as part of the public +/// input vector `x` per step, and need the verifier to enforce that the state chains across steps. 
+#[derive(Clone, Debug)] +pub struct StepLinkingConfig { + /// Equalities on adjacent steps: require `steps[i].x[prev_idx] == steps[i+1].x[next_idx]`. + pub prev_next_equalities: Vec<(usize, usize)>, +} + +impl StepLinkingConfig { + pub fn new(prev_next_equalities: Vec<(usize, usize)>) -> Self { + Self { prev_next_equalities } + } +} + +pub fn check_step_linking(steps: &[StepInstanceBundle], cfg: &StepLinkingConfig) -> Result<(), PiCcsError> { + if steps.len() <= 1 || cfg.prev_next_equalities.is_empty() { + return Ok(()); + } + for (i, (prev, next)) in steps.iter().zip(steps.iter().skip(1)).enumerate() { + let prev_x = &prev.mcs_inst.x; + let next_x = &next.mcs_inst.x; + for &(prev_idx, next_idx) in &cfg.prev_next_equalities { + if prev_idx >= prev_x.len() || next_idx >= next_x.len() { + return Err(PiCcsError::InvalidInput(format!( + "step linking index out of range at boundary {i}: prev_x.len()={}, next_x.len()={}, pair=({prev_idx},{next_idx})", + prev_x.len(), + next_x.len(), + ))); + } + if prev_x[prev_idx] != next_x[next_idx] { + return Err(PiCcsError::ProtocolError(format!( + "step linking failed at boundary {i}: prev_x[{prev_idx}] != next_x[{next_idx}]", + ))); + } + } + } + Ok(()) +} + +/// Commitment mixers so the coordinator stays scheme-agnostic. +/// - `mix_rhos_commits(ρ, cs)` returns Σ ρ_i · c_i (S-action). +/// - `combine_b_pows(cs, b)` returns Σ \bar b^{i-1} c_i (DEC check). 
+#[derive(Clone, Copy)] +pub struct CommitMixers +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt, + MB: Fn(&[Cmt], u32) -> Cmt, +{ + pub mix_rhos_commits: MR, + pub combine_b_pows: MB, +} + +pub fn normalize_me_claims( + me_claims: &mut [MeInstance], + ell_n: usize, + ell_d: usize, + t: usize, +) -> Result<(), PiCcsError> { + let y_pad = 1usize << ell_d; + for (i, me) in me_claims.iter_mut().enumerate() { + if me.r.len() != ell_n { + return Err(PiCcsError::InvalidInput(format!( + "ME[{}] r.len()={}, expected ell_n={}", + i, + me.r.len(), + ell_n + ))); + } + if me.y.len() > t { + return Err(PiCcsError::InvalidInput(format!( + "ME[{}] y.len()={}, expected <= t={}", + i, + me.y.len(), + t + ))); + } + for (j, row) in me.y.iter_mut().enumerate() { + if row.len() > y_pad { + return Err(PiCcsError::InvalidInput(format!( + "ME[{}] y[{}].len()={}, expected <= {}", + i, + j, + row.len(), + y_pad + ))); + } + row.resize(y_pad, K::ZERO); + } + me.y.resize_with(t, || vec![K::ZERO; y_pad]); + if me.y_scalars.len() > t { + return Err(PiCcsError::InvalidInput(format!( + "ME[{}] y_scalars.len()={}, expected <= t={}", + i, + me.y_scalars.len(), + t + ))); + } + me.y_scalars.resize(t, K::ZERO); + } + Ok(()) +} + +pub(crate) fn validate_me_batch_invariants(batch: &[MeInstance], context: &str) -> Result<(), PiCcsError> { + if batch.is_empty() { + return Ok(()); + } + let me0 = &batch[0]; + let r0 = &me0.r; + let m_in0 = me0.m_in; + let y_len0 = me0.y.len(); + let y_row_len0 = me0.y.first().map(|r| r.len()).unwrap_or(0); + let y_scalars_len0 = me0.y_scalars.len(); + + if me0.X.rows() != D { + return Err(PiCcsError::ProtocolError(format!( + "{}: ME claim 0 has X.rows()={}, expected D={}", + context, + me0.X.rows(), + D + ))); + } + if me0.X.cols() != m_in0 { + return Err(PiCcsError::ProtocolError(format!( + "{}: ME claim 0 has X.cols()={}, expected m_in={}", + context, + me0.X.cols(), + m_in0 + ))); + } + + for (i, me) in batch.iter().enumerate().skip(1) { + if me.r != *r0 { + return 
Err(PiCcsError::ProtocolError(format!( + "{}: ME claim {} has different r than claim 0 (r-alignment required for RLC)", + context, i + ))); + } + if me.m_in != m_in0 { + return Err(PiCcsError::ProtocolError(format!( + "{}: ME claim {} has m_in={}, expected {}", + context, i, me.m_in, m_in0 + ))); + } + if me.X.rows() != D || me.X.cols() != m_in0 { + return Err(PiCcsError::ProtocolError(format!( + "{}: ME claim {} has X shape {}x{}, expected {}x{}", + context, + i, + me.X.rows(), + me.X.cols(), + D, + m_in0 + ))); + } + if me.y.len() != y_len0 { + return Err(PiCcsError::ProtocolError(format!( + "{}: ME claim {} has y.len()={}, expected {}", + context, + i, + me.y.len(), + y_len0 + ))); + } + for (j, row) in me.y.iter().enumerate() { + if row.len() != y_row_len0 { + return Err(PiCcsError::ProtocolError(format!( + "{}: ME claim {} has y[{}].len()={}, expected {}", + context, + i, + j, + row.len(), + y_row_len0 + ))); + } + } + if me.y_scalars.len() != y_scalars_len0 { + return Err(PiCcsError::ProtocolError(format!( + "{}: ME claim {} has y_scalars.len()={}, expected {}", + context, + i, + me.y_scalars.len(), + y_scalars_len0 + ))); + } + } + Ok(()) +} + +#[derive(Clone, Copy, Debug)] +pub(crate) enum RlcLane { + Main, + Val, +} + +#[inline] +pub(crate) fn balanced_divrem_i64(v: i64, b: i64) -> (i64, i64) { + debug_assert!(b >= 2); + let mut r = v % b; + let mut q = (v - r) / b; + let half = b / 2; + if r > half { + r -= b; + q += 1; + } else if r < -half { + r += b; + q -= 1; + } + (r, q) +} + +#[inline] +pub(crate) fn balanced_divrem_i128(v: i128, b: i128) -> (i128, i128) { + debug_assert!(b >= 2); + let mut r = v % b; + let mut q = (v - r) / b; + let half = b / 2; + if r > half { + r -= b; + q += 1; + } else if r < -half { + r += b; + q -= 1; + } + (r, q) +} + +#[inline] +pub(crate) fn f_from_i64(x: i64) -> F { + if x >= 0 { + F::from_u64(x as u64) + } else { + F::ZERO - F::from_u64((-x) as u64) + } +} + +#[inline] +pub(crate) fn verify_me_y_scalars_canonical( + me: 
&MeInstance, + b: u32, + step_idx: usize, + context: &str, +) -> Result<(), PiCcsError> { + if me.y_scalars.len() != me.y.len() { + return Err(PiCcsError::InvalidInput(format!( + "step {}: {}: y_scalars.len()={} must equal y.len()={}", + step_idx, + context, + me.y_scalars.len(), + me.y.len() + ))); + } + let bK = K::from(F::from_u64(b as u64)); + for (j, row) in me.y.iter().enumerate() { + if row.len() < D { + return Err(PiCcsError::InvalidInput(format!( + "step {}: {}: y[{}].len()={} must be >= D={}", + step_idx, + context, + j, + row.len(), + D + ))); + } + let mut expect = K::ZERO; + let mut pow = K::ONE; + for rho in 0..D { + expect += pow * row[rho]; + pow *= bK; + } + if me.y_scalars[j] != expect { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {}: non-canonical y_scalars at row {}", + step_idx, context, j + ))); + } + } + Ok(()) +} + +pub(crate) fn dec_stream_no_witness( + params: &NeoParams, + s: &CcsStructure, + parent: &MeInstance, + Z_mix: &Mat, + ell_d: usize, + k_dec: usize, + combine_b_pows: MB, + sparse: Option<&SparseCache>, +) -> Result<(Vec>, Vec, bool, bool, bool), PiCcsError> +where + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + if k_dec == 0 { + return Err(PiCcsError::InvalidInput("DEC: k_dec must be > 0".into())); + } + if Z_mix.rows() != D || Z_mix.cols() != s.m { + return Err(PiCcsError::InvalidInput(format!( + "DEC: Z_mix must have shape D×m = {}×{} (got {}×{})", + D, + s.m, + Z_mix.rows(), + Z_mix.cols() + ))); + } + + let d_pad = 1usize << ell_d; + let want_nc_channel = !(parent.s_col.is_empty() && parent.y_zcol.is_empty()); + if want_nc_channel && (parent.s_col.is_empty() || parent.y_zcol.is_empty()) { + return Err(PiCcsError::InvalidInput( + "DEC: incomplete NC channel on parent (expected both s_col and y_zcol)".into(), + )); + } + if want_nc_channel && parent.y_zcol.len() != d_pad { + return Err(PiCcsError::InvalidInput(format!( + "DEC: parent y_zcol length mismatch (expected {}, got {})", + d_pad, + 
parent.y_zcol.len() + ))); + } + + enum PpAccess { + Seeded { + kappa: usize, + chunk_size: usize, + chunk_seeds_by_row: Vec>, + }, + Loaded { + pp: Arc>, + }, + } + + let pp_access = if let Some(pp) = try_get_loaded_global_pp_for_dims(D, s.m) { + if pp.kappa == 0 { + return Err(PiCcsError::InvalidInput("DEC: PP.kappa must be > 0".into())); + } + PpAccess::Loaded { pp } + } else if let Ok((kappa, seed)) = get_global_pp_seeded_params_for_dims(D, s.m) { + if kappa == 0 { + return Err(PiCcsError::InvalidInput("DEC: PP.kappa must be > 0".into())); + } + let (chunk_size, chunk_seeds_by_row) = seeded_pp_chunk_seeds(seed, kappa, s.m); + PpAccess::Seeded { + kappa, + chunk_size, + chunk_seeds_by_row, + } + } else { + // Fallback: non-seeded entry. This will materialize PP if needed. + let pp = get_global_pp_for_dims(D, s.m).map_err(|e| { + PiCcsError::InvalidInput(format!("DEC: Ajtai PP unavailable for (d,m)=({},{}) ({})", D, s.m, e)) + })?; + if pp.kappa == 0 { + return Err(PiCcsError::InvalidInput("DEC: PP.kappa must be > 0".into())); + } + PpAccess::Loaded { pp } + }; + + // Build χ_r and v_j = M_j^T · χ_r (same as the reference DEC). + let ell_n = parent.r.len(); + let n_sz = 1usize + .checked_shl(ell_n as u32) + .ok_or_else(|| PiCcsError::InvalidInput("DEC: 2^ell_n overflow".into()))?; + let n_eff = core::cmp::min(s.n, n_sz); + + // χ_r table over the row/time hypercube. + // + // IMPORTANT: Use the same bit order as `eq_points_bool_mask` / `chi_tail_weights` + // (bit 0 = LSB) so CSC column traversals match the reference DEC. 
+ #[inline] + fn chi_tail_weights(bits: &[K]) -> Vec { + let t = bits.len(); + let len = 1usize << t; + let mut w = vec![K::ZERO; len]; + w[0] = K::ONE; + for (i, &b) in bits.iter().enumerate() { + let step = 1usize << i; + let one_minus = K::ONE - b; + for mask in 0..step { + let v = w[mask]; + w[mask] = v * one_minus; + w[mask + step] = v * b; + } + } + w + } + + let chi_r = chi_tail_weights(&parent.r); + debug_assert_eq!(chi_r.len(), n_sz); + + let chi_s = if want_nc_channel { + let chi = chi_tail_weights(&parent.s_col); + if chi.len() < s.m { + return Err(PiCcsError::InvalidInput(format!( + "DEC: chi(s_col) too short for CCS width (need >= {}, got {})", + s.m, + chi.len() + ))); + } + chi + } else { + Vec::new() + }; + + let t_mats = s.t(); + + enum VjsAccess<'a> { + Dense(Vec>), + Sparse { + cap: usize, + cache: &'a SparseCache, + }, + } + + let vjs_access = if let Some(cache) = sparse { + if cache.len() != t_mats { + return Err(PiCcsError::InvalidInput(format!( + "DEC: sparse cache matrix count mismatch: got {}, expected {}", + cache.len(), + t_mats + ))); + } + let cap = core::cmp::min(s.m, n_eff); + VjsAccess::Sparse { cap, cache } + } else { + let mut vjs: Vec> = vec![vec![K::ZERO; s.m]; t_mats]; + for j in 0..t_mats { + s.matrices[j].add_mul_transpose_into(&chi_r, &mut vjs[j], n_eff); + } + VjsAccess::Dense(vjs) + }; + + // Base-b powers in K for y_scalar recomposition. + let bF = F::from_u64(params.b as u64); + let bK = K::from(bF); + let mut pow_b_k = [K::ONE; D]; + for rho in 1..D { + pow_b_k[rho] = pow_b_k[rho - 1] * bK; + } + + // Precompute parameters for bounded signed decoding of Z_mix entries. + let b_u = params.b as u128; + let mut B_u: u128 = 1; + for _ in 0..k_dec { + B_u = B_u.saturating_mul(b_u); + } + let p: u128 = F::ORDER_U64 as u128; + + // Fast row-major access. 
+ let z_rows: Vec<&[F]> = (0..D).map(|r| Z_mix.row(r)).collect(); + + struct Acc { + commit: Vec<[F; D]>, // [digit][kappa] -> [D] + y: Vec<[K; D]>, // [digit][t] -> [D] + y_zcol: Vec<[K; D]>, // [digit] -> [D] + any_nonzero: Vec, + vj: Vec, // scratch: t + digits: Vec, // scratch: k*D (balanced digits) + rot_next: [F; D], // scratch: rotation step output (written fully each time) + err: Option, // first error wins + } + + impl Acc { + fn new(k_dec: usize, kappa: usize, t: usize) -> Self { + Self { + commit: vec![[F::ZERO; D]; k_dec * kappa], + y: vec![[K::ZERO; D]; k_dec * t], + y_zcol: vec![[K::ZERO; D]; k_dec], + any_nonzero: vec![false; k_dec], + vj: vec![K::ZERO; t], + digits: vec![0i32; k_dec * D], + rot_next: [F::ZERO; D], + err: None, + } + } + + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + fn add_inplace(&mut self, rhs: &Acc, k_dec: usize, kappa: usize, t: usize) { + for (dst, src) in self.commit.iter_mut().zip(rhs.commit.iter()) { + for r in 0..D { + dst[r] += src[r]; + } + } + for (dst, src) in self.y.iter_mut().zip(rhs.y.iter()) { + for r in 0..D { + dst[r] += src[r]; + } + } + for (dst, src) in self.y_zcol.iter_mut().zip(rhs.y_zcol.iter()) { + for r in 0..D { + dst[r] += src[r]; + } + } + for i in 0..k_dec { + self.any_nonzero[i] |= rhs.any_nonzero[i]; + } + if self.err.is_none() { + self.err = rhs.err.clone(); + } + // silence unused warnings when parameters are const-propagated + let _ = (k_dec, kappa, t); + } + } + + let m = s.m; + let b_i64 = params.b as i64; + let b_i128 = params.b as i128; + + // Specialized rot_step for Φ₈₁(X) = X^54 + X^27 + 1 (η=81, D=54). + // Mirrors `neo_ajtai::commit::rot_step_phi_81` but kept local to avoid pulling a large + // D×D scratch table (`precompute_rot_columns`) into the hot DEC streaming loop. 
+ #[inline] + fn rot_step_phi_81(cur: &[F; D], next: &mut [F; D]) { + let last = cur[D - 1]; + next[0] = F::ZERO; + next[1..D].copy_from_slice(&cur[..(D - 1)]); + next[0] -= last; + next[27] -= last; + } + + #[inline] + fn acc_add_assign(acc: &mut [F; D], col: &[F; D]) { + type P = ::Packing; + let prefix_len = D - (D % P::WIDTH); + let (acc_prefix, acc_suffix) = acc.split_at_mut(prefix_len); + let (col_prefix, col_suffix) = col.split_at(prefix_len); + + for (a, b) in P::pack_slice_mut(acc_prefix) + .iter_mut() + .zip(P::pack_slice(col_prefix).iter()) + { + *a += *b; + } + for (a, &b) in acc_suffix.iter_mut().zip(col_suffix.iter()) { + *a += b; + } + } + + #[inline] + fn acc_sub_assign(acc: &mut [F; D], col: &[F; D]) { + type P = ::Packing; + let prefix_len = D - (D % P::WIDTH); + let (acc_prefix, acc_suffix) = acc.split_at_mut(prefix_len); + let (col_prefix, col_suffix) = col.split_at(prefix_len); + + for (a, b) in P::pack_slice_mut(acc_prefix) + .iter_mut() + .zip(P::pack_slice(col_prefix).iter()) + { + *a -= *b; + } + for (a, &b) in acc_suffix.iter_mut().zip(col_suffix.iter()) { + *a -= b; + } + } + + #[inline] + fn acc_mul_add_assign(acc: &mut [F; D], col: &[F; D], scalar: F) { + type P = ::Packing; + let prefix_len = D - (D % P::WIDTH); + let (acc_prefix, acc_suffix) = acc.split_at_mut(prefix_len); + let (col_prefix, col_suffix) = col.split_at(prefix_len); + let scalar_p: P = scalar.into(); + + for (a, b) in P::pack_slice_mut(acc_prefix) + .iter_mut() + .zip(P::pack_slice(col_prefix).iter()) + { + *a += *b * scalar_p; + } + for (a, &b) in acc_suffix.iter_mut().zip(col_suffix.iter()) { + *a += b * scalar; + } + } + + let (kappa, acc) = match &pp_access { + PpAccess::Loaded { pp } => { + let kappa = pp.kappa; + let process_col = |mut st: Acc, col: usize| -> Acc { + if st.err.is_some() { + return st; + } + + // Decompose the column's D entries into balanced base-b digits for each DEC child. 
+ for rho in 0..D { + let u = z_rows[rho][col].as_canonical_u64() as u128; + if B_u <= i64::MAX as u128 { + let val_opt: Option = if u < B_u { + Some(u as i64) + } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { + Some(-((p - u) as i64)) + } else { + None + }; + let mut v = match val_opt { + Some(v) => v, + None => { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", + rho, col, k_dec, params.b + )); + return st; + } + }; + for i in 0..k_dec { + if v == 0 { + st.digits[i * D + rho] = 0; + continue; + } + let (r_i, q) = balanced_divrem_i64(v, b_i64); + if r_i != 0 { + st.any_nonzero[i] = true; + } + st.digits[i * D + rho] = r_i as i32; + v = q; + } + if v != 0 { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", + rho, col, k_dec, params.b + )); + return st; + } + } else { + let val_opt: Option = if u < B_u { + Some(u as i128) + } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { + Some(-((p - u) as i128)) + } else { + None + }; + let mut v = match val_opt { + Some(v) => v, + None => { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", + rho, col, k_dec, params.b + )); + return st; + } + }; + for i in 0..k_dec { + if v == 0 { + st.digits[i * D + rho] = 0; + continue; + } + let (r_i, q) = balanced_divrem_i128(v, b_i128); + if r_i != 0 { + st.any_nonzero[i] = true; + } + st.digits[i * D + rho] = r_i as i32; + v = q; + } + if v != 0 { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", + rho, col, k_dec, params.b + )); + return st; + } + } + } + + // vj[col] := M_j^T · χ_r (compute per column to avoid materializing all vjs). 
+ match &vjs_access { + VjsAccess::Dense(vjs) => { + for j in 0..t_mats { + st.vj[j] = vjs[j][col]; + } + } + VjsAccess::Sparse { cap, cache } => { + for j in 0..t_mats { + st.vj[j] = if let Some(csc) = cache.csc(j) { + let mut sum = K::ZERO; + let s = csc.col_ptr[col]; + let e = csc.col_ptr[col + 1]; + for k in s..e { + let r = csc.row_idx[k]; + if r < n_eff { + sum += K::from(csc.vals[k]) * chi_r[r]; + } + } + sum + } else if col < *cap { + chi_r[col] + } else { + K::ZERO + }; + } + } + } + + // y_(i,j)[rho] += Z_i[rho,col] * vj[col] + for i in 0..k_dec { + let y_base = i * t_mats; + for rho in 0..D { + let digit = st.digits[i * D + rho]; + if digit == 0 { + continue; + } + for j in 0..t_mats { + let vj = st.vj[j]; + if vj != K::ZERO { + match digit { + 1 => st.y[y_base + j][rho] += vj, + -1 => st.y[y_base + j][rho] -= vj, + _ => st.y[y_base + j][rho] += vj.scale_base(f_from_i64(digit as i64)), + } + } + } + } + } + + // y_zcol_i[rho] += Z_i[rho,col] * χ_{s_col}[col] (optional). + if !chi_s.is_empty() { + let w_col = chi_s[col]; + if w_col != K::ZERO { + for i in 0..k_dec { + for rho in 0..D { + let digit = st.digits[i * D + rho]; + if digit == 0 { + continue; + } + match digit { + 1 => st.y_zcol[i][rho] += w_col, + -1 => st.y_zcol[i][rho] -= w_col, + _ => st.y_zcol[i][rho] += w_col.scale_base(f_from_i64(digit as i64)), + } + } + } + } + } + + // Commitment accumulators per digit. 
+ for kr in 0..kappa { + let mut rot_col = neo_math::ring::cf(pp.m_rows[kr][col]); + for rho in 0..D { + for i in 0..k_dec { + let digit = st.digits[i * D + rho]; + if digit == 0 { + continue; + } + let acc = &mut st.commit[i * kappa + kr]; + match digit { + 1 => acc_add_assign(acc, &rot_col), + -1 => acc_sub_assign(acc, &rot_col), + _ => acc_mul_add_assign(acc, &rot_col, f_from_i64(digit as i64)), + } + } + rot_step_phi_81(&rot_col, &mut st.rot_next); + core::mem::swap(&mut rot_col, &mut st.rot_next); + } + } + + st + }; + + let acc = { + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + (0..m) + .into_par_iter() + .fold(|| Acc::new(k_dec, kappa, t_mats), |st, col| process_col(st, col)) + .reduce( + || Acc::new(k_dec, kappa, t_mats), + |mut a, b| { + if a.err.is_none() { + a.add_inplace(&b, k_dec, kappa, t_mats); + } + a + }, + ) + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + let mut st = Acc::new(k_dec, kappa, t_mats); + for col in 0..m { + st = process_col(st, col); + } + st + } + }; + (kappa, acc) + } + PpAccess::Seeded { + kappa, + chunk_size, + chunk_seeds_by_row, + } => { + let kappa = *kappa; + let chunk_size = *chunk_size; + let num_chunks = (m + chunk_size - 1) / chunk_size; + + let process_chunk = |mut st: Acc, chunk_idx: usize| -> Acc { + if st.err.is_some() { + return st; + } + + let start = chunk_idx * chunk_size; + let end = core::cmp::min(m, start + chunk_size); + + let mut rngs: Vec = (0..kappa) + .map(|kr| ChaCha8Rng::from_seed(chunk_seeds_by_row[kr][chunk_idx])) + .collect(); + + for col in start..end { + // Decompose the column's D entries into balanced base-b digits for each DEC child. 
+ for rho in 0..D { + let u = z_rows[rho][col].as_canonical_u64() as u128; + if B_u <= i64::MAX as u128 { + let val_opt: Option = if u < B_u { + Some(u as i64) + } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { + Some(-((p - u) as i64)) + } else { + None + }; + let mut v = match val_opt { + Some(v) => v, + None => { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", + rho, col, k_dec, params.b + )); + return st; + } + }; + for i in 0..k_dec { + if v == 0 { + st.digits[i * D + rho] = 0; + continue; + } + let (r_i, q) = balanced_divrem_i64(v, b_i64); + if r_i != 0 { + st.any_nonzero[i] = true; + } + st.digits[i * D + rho] = r_i as i32; + v = q; + } + if v != 0 { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", + rho, col, k_dec, params.b + )); + return st; + } + } else { + let val_opt: Option = if u < B_u { + Some(u as i128) + } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { + Some(-((p - u) as i128)) + } else { + None + }; + let mut v = match val_opt { + Some(v) => v, + None => { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] is out of range for k_rho={}, b={}", + rho, col, k_dec, params.b + )); + return st; + } + }; + for i in 0..k_dec { + if v == 0 { + st.digits[i * D + rho] = 0; + continue; + } + let (r_i, q) = balanced_divrem_i128(v, b_i128); + if r_i != 0 { + st.any_nonzero[i] = true; + } + st.digits[i * D + rho] = r_i as i32; + v = q; + } + if v != 0 { + st.err = Some(format!( + "DEC split: Z_mix[{},{}] needs more than k_rho={} digits in base b={}", + rho, col, k_dec, params.b + )); + return st; + } + } + } + + // vj[col] := M_j^T · χ_r (compute per column to avoid materializing all vjs). 
+ match &vjs_access { + VjsAccess::Dense(vjs) => { + for j in 0..t_mats { + st.vj[j] = vjs[j][col]; + } + } + VjsAccess::Sparse { cap, cache } => { + for j in 0..t_mats { + st.vj[j] = if let Some(csc) = cache.csc(j) { + let mut sum = K::ZERO; + let s = csc.col_ptr[col]; + let e = csc.col_ptr[col + 1]; + for k in s..e { + let r = csc.row_idx[k]; + if r < n_eff { + sum += K::from(csc.vals[k]) * chi_r[r]; + } + } + sum + } else if col < *cap { + chi_r[col] + } else { + K::ZERO + }; + } + } + } + + // y_(i,j)[rho] += Z_i[rho,col] * vj[col] + for i in 0..k_dec { + let y_base = i * t_mats; + for rho in 0..D { + let digit = st.digits[i * D + rho]; + if digit == 0 { + continue; + } + for j in 0..t_mats { + let vj = st.vj[j]; + if vj != K::ZERO { + match digit { + 1 => st.y[y_base + j][rho] += vj, + -1 => st.y[y_base + j][rho] -= vj, + _ => st.y[y_base + j][rho] += vj.scale_base(f_from_i64(digit as i64)), + } + } + } + } + } + + // y_zcol_i[rho] += Z_i[rho,col] * χ_{s_col}[col] (optional). + if !chi_s.is_empty() { + let w_col = chi_s[col]; + if w_col != K::ZERO { + for i in 0..k_dec { + for rho in 0..D { + let digit = st.digits[i * D + rho]; + if digit == 0 { + continue; + } + match digit { + 1 => st.y_zcol[i][rho] += w_col, + -1 => st.y_zcol[i][rho] -= w_col, + _ => st.y_zcol[i][rho] += w_col.scale_base(f_from_i64(digit as i64)), + } + } + } + } + } + + // Commitment accumulators per digit. 
+ for kr in 0..kappa { + let a_kr_col = sample_uniform_rq(&mut rngs[kr]); + let mut rot_col = neo_math::ring::cf(a_kr_col); + for rho in 0..D { + for i in 0..k_dec { + let digit = st.digits[i * D + rho]; + if digit == 0 { + continue; + } + let acc = &mut st.commit[i * kappa + kr]; + match digit { + 1 => acc_add_assign(acc, &rot_col), + -1 => acc_sub_assign(acc, &rot_col), + _ => acc_mul_add_assign(acc, &rot_col, f_from_i64(digit as i64)), + } + } + rot_step_phi_81(&rot_col, &mut st.rot_next); + core::mem::swap(&mut rot_col, &mut st.rot_next); + } + } + } + + st + }; + + let acc = { + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + (0..num_chunks) + .into_par_iter() + .fold( + || Acc::new(k_dec, kappa, t_mats), + |st, chunk_idx| process_chunk(st, chunk_idx), + ) + .reduce( + || Acc::new(k_dec, kappa, t_mats), + |mut a, b| { + if a.err.is_none() { + a.add_inplace(&b, k_dec, kappa, t_mats); + } + a + }, + ) + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + let mut st = Acc::new(k_dec, kappa, t_mats); + for chunk_idx in 0..num_chunks { + st = process_chunk(st, chunk_idx); + } + st + } + }; + (kappa, acc) + } + }; + + if let Some(err) = acc.err { + return Err(PiCcsError::ProtocolError(err)); + } + + // Commitments c_i from accumulated columns. + let mut child_cs: Vec = Vec::with_capacity(k_dec); + for i in 0..k_dec { + if !acc.any_nonzero[i] { + child_cs.push(Cmt::zeros(D, kappa)); + continue; + } + let mut c = Cmt::zeros(D, kappa); + for kr in 0..kappa { + c.col_mut(kr).copy_from_slice(&acc.commit[i * kappa + kr]); + } + child_cs.push(c); + } + + // X_i: project first m_in columns from Z_i (small; compute sequentially). 
+ let m_in = parent.m_in; + let mut xs_row_major: Vec> = vec![vec![F::ZERO; D * m_in]; k_dec]; + for col in 0..m_in { + for rho in 0..D { + let u = z_rows[rho][col].as_canonical_u64() as u128; + if B_u <= i64::MAX as u128 { + let val_opt: Option = if u < B_u { + Some(u as i64) + } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { + Some(-((p - u) as i64)) + } else { + None + }; + let mut v = val_opt.ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "DEC split(X): Z_mix[{},{}] out of range for k_rho={}, b={}", + rho, col, k_dec, params.b + )) + })?; + for i in 0..k_dec { + if v == 0 { + break; + } + let (r_i, q) = balanced_divrem_i64(v, b_i64); + xs_row_major[i][rho * m_in + col] = f_from_i64(r_i); + v = q; + } + if v != 0 { + return Err(PiCcsError::ProtocolError(format!( + "DEC split(X): Z_mix[{},{}] needs more than k_rho={} digits in base b={}", + rho, col, k_dec, params.b + ))); + } + } else { + let val_opt: Option = if u < B_u { + Some(u as i128) + } else if p.checked_sub(u).map(|w| w < B_u).unwrap_or(false) { + Some(-((p - u) as i128)) + } else { + None + }; + let mut v = val_opt.ok_or_else(|| { + PiCcsError::ProtocolError(format!( + "DEC split(X): Z_mix[{},{}] out of range for k_rho={}, b={}", + rho, col, k_dec, params.b + )) + })?; + for i in 0..k_dec { + if v == 0 { + break; + } + let (r_i, q) = balanced_divrem_i128(v, b_i128); + xs_row_major[i][rho * m_in + col] = f_from_i64(r_i as i64); + v = q; + } + if v != 0 { + return Err(PiCcsError::ProtocolError(format!( + "DEC split(X): Z_mix[{},{}] needs more than k_rho={} digits in base b={}", + rho, col, k_dec, params.b + ))); + } + } + } + } + + let parent_r = parent.r.clone(); + let fold_digest = parent.fold_digest; + + let mut children: Vec> = Vec::with_capacity(k_dec); + for i in 0..k_dec { + let Xi = Mat::from_row_major(D, m_in, xs_row_major[i].clone()); + let mut y_i: Vec> = Vec::with_capacity(t_mats); + let mut y_scalars_i: Vec = Vec::with_capacity(t_mats); + for j in 0..t_mats { + let mut 
yj = vec![K::ZERO; d_pad]; + let row = &acc.y[i * t_mats + j]; + for rho in 0..D { + yj[rho] = row[rho]; + } + let mut sc = K::ZERO; + for rho in 0..D { + sc += yj[rho] * pow_b_k[rho]; + } + y_i.push(yj); + y_scalars_i.push(sc); + } + + let y_zcol = if chi_s.is_empty() { + Vec::new() + } else { + let mut yz = vec![K::ZERO; d_pad]; + let row = &acc.y_zcol[i]; + for rho in 0..D { + yz[rho] = row[rho]; + } + yz + }; + + children.push(MeInstance:: { + c_step_coords: vec![], + u_offset: 0, + u_len: 0, + c: child_cs[i].clone(), + X: Xi, + r: parent_r.clone(), + s_col: parent.s_col.clone(), + y: y_i, + y_scalars: y_scalars_i, + y_zcol, + m_in, + fold_digest, + }); + } + + // Public checks (mirror paper-exact DEC). + let mut ok_y = true; + for j in 0..t_mats { + let mut lhs = vec![K::ZERO; d_pad]; + let mut pow = K::ONE; + for i in 0..k_dec { + for t in 0..d_pad { + lhs[t] += pow * children[i].y[j][t]; + } + pow *= bK; + } + if lhs != parent.y[j] { + ok_y = false; + break; + } + } + + // y_zcol: column-domain opening must also decompose (when present). 
+ if ok_y && !chi_s.is_empty() { + let mut lhs = vec![K::ZERO; d_pad]; + let mut pow = K::ONE; + for i in 0..k_dec { + for t in 0..d_pad { + lhs[t] += pow * children[i].y_zcol[t]; + } + pow *= bK; + } + if lhs != parent.y_zcol { + ok_y = false; + } + } + + let mut lhs_X = Mat::zero(D, m_in, F::ZERO); + let mut pow = F::ONE; + for i in 0..k_dec { + for r in 0..D { + for c in 0..m_in { + lhs_X[(r, c)] += pow * children[i].X[(r, c)]; + } + } + pow *= bF; + } + let ok_X = lhs_X.as_slice() == parent.X.as_slice(); + + let ok_c = combine_b_pows(&child_cs, params.b) == parent.c; + Ok((children, child_cs, ok_y, ok_X, ok_c)) +} + +pub(crate) fn bind_rlc_inputs( + tr: &mut Poseidon2Transcript, + lane: RlcLane, + step_idx: usize, + me_inputs: &[MeInstance], +) -> Result<(), PiCcsError> { + let lane_scope: &'static [u8] = match lane { + RlcLane::Main => b"main", + RlcLane::Val => b"val", + }; + + // v2: binds NC-channel fields (s_col, y_zcol) so RLC challenges depend on the full instance. + tr.append_message(b"fold/rlc_inputs/v2", lane_scope); + tr.append_u64s(b"step_idx", &[step_idx as u64]); + tr.append_u64s(b"me_count", &[me_inputs.len() as u64]); + + for me in me_inputs { + tr.append_fields(b"c_data", &me.c.data); + tr.append_u64s(b"m_in", &[me.m_in as u64]); + tr.append_message(b"me_fold_digest", &me.fold_digest); + + let r_coeffs_per_limb = me.r.first().map(|v| v.as_coeffs().len()).unwrap_or(0); + tr.append_fields_iter( + b"r_limb", + me.r.len() + .checked_mul(r_coeffs_per_limb) + .ok_or_else(|| PiCcsError::ProtocolError("r_limb length overflow".into()))?, + me.r.iter().flat_map(|limb| limb.as_coeffs()), + ); + + tr.append_u64s(b"s_col_len", &[me.s_col.len() as u64]); + let s_col_coeffs_per_elem = me.s_col.first().map(|v| v.as_coeffs().len()).unwrap_or(0); + tr.append_fields_iter( + b"s_col_elem", + me.s_col + .len() + .checked_mul(s_col_coeffs_per_elem) + .ok_or_else(|| PiCcsError::ProtocolError("s_col_elem length overflow".into()))?, + me.s_col.iter().flat_map(|sc| 
sc.as_coeffs()), + ); + + tr.append_u64s(b"y_zcol_len", &[me.y_zcol.len() as u64]); + let y_zcol_coeffs_per_elem = me.y_zcol.first().map(|v| v.as_coeffs().len()).unwrap_or(0); + tr.append_fields_iter( + b"y_zcol_elem", + me.y_zcol + .len() + .checked_mul(y_zcol_coeffs_per_elem) + .ok_or_else(|| PiCcsError::ProtocolError("y_zcol_elem length overflow".into()))?, + me.y_zcol.iter().flat_map(|yz| yz.as_coeffs()), + ); + + tr.append_fields(b"X", me.X.as_slice()); + + let y_elem_coeffs_per_elem = + me.y.iter() + .find_map(|row| row.first()) + .map(|v| v.as_coeffs().len()) + .unwrap_or(0); + let y_elem_count = me.y.iter().map(Vec::len).sum::(); + tr.append_fields_iter( + b"y_elem", + y_elem_count + .checked_mul(y_elem_coeffs_per_elem) + .ok_or_else(|| PiCcsError::ProtocolError("y_elem length overflow".into()))?, + me.y.iter() + .flat_map(|row| row.iter().flat_map(|v| v.as_coeffs())), + ); + + let y_scalar_coeffs_per_elem = me + .y_scalars + .first() + .map(|v| v.as_coeffs().len()) + .unwrap_or(0); + tr.append_fields_iter( + b"y_scalar", + me.y_scalars + .len() + .checked_mul(y_scalar_coeffs_per_elem) + .ok_or_else(|| PiCcsError::ProtocolError("y_scalar length overflow".into()))?, + me.y_scalars.iter().flat_map(|ysc| ysc.as_coeffs()), + ); + + tr.append_u64s(b"c_step_coords_len", &[me.c_step_coords.len() as u64]); + tr.append_fields(b"c_step_coords", &me.c_step_coords); + tr.append_u64s(b"u_offset", &[me.u_offset as u64]); + tr.append_u64s(b"u_len", &[me.u_len as u64]); + } + + Ok(()) +} diff --git a/crates/neo-fold/src/shard/prover.rs b/crates/neo-fold/src/shard/prover.rs new file mode 100644 index 00000000..74f74fc8 --- /dev/null +++ b/crates/neo-fold/src/shard/prover.rs @@ -0,0 +1,1271 @@ +use super::*; + +#[derive(Clone)] +pub(crate) struct ShardProverContext { + pub ccs_mat_digest: Vec, + pub ccs_sparse_cache: Option>>, +} + +#[inline] +pub(crate) fn mode_uses_sparse_cache(mode: &FoldingMode) -> bool { + match mode { + FoldingMode::Optimized => true, + #[cfg(feature = 
"paper-exact")] + FoldingMode::OptimizedWithCrosscheck(_) => true, + #[cfg(feature = "paper-exact")] + FoldingMode::PaperExact => false, + } +} + +pub(crate) fn fold_shard_prove_impl( + collect_val_lane_wits: bool, + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + step_idx_offset: usize, + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, + ob: Option<(&crate::output_binding::OutputBindingConfig, &[F])>, + prover_ctx: Option<&ShardProverContext>, + mut step_prove_ms_out: Option<&mut Vec>, +) -> Result<(ShardProof, Vec>, Vec>), PiCcsError> +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + for (step_idx, step) in steps.iter().enumerate() { + if step.lut_instances.is_empty() && step.mem_instances.is_empty() { + continue; + } + let is_shared_step = step + .lut_instances + .iter() + .all(|(inst, wit)| inst.comms.is_empty() && wit.mats.is_empty()) + && step + .mem_instances + .iter() + .all(|(inst, wit)| inst.comms.is_empty() && wit.mats.is_empty()); + if !is_shared_step { + return Err(PiCcsError::InvalidInput(format!( + "legacy no-shared CPU bus mode was removed; step_idx={step_idx} must use shared-bus witness format" + ))); + } + } + tr.append_message(b"shard/cpu_bus_mode", &[1u8]); + let (s, cpu_bus) = crate::memory_sidecar::cpu_bus::prepare_ccs_for_shared_cpu_bus_steps(s_me, steps)?; + let dims = utils::build_dims_and_policy(params, s)?; + let utils::Dims { + ell_d, + ell_n, + ell_m, + ell, + d_sc, + .. 
+ } = dims; + let ccs_sparse_cache: Option>> = if mode_uses_sparse_cache(&mode) { + Some( + prover_ctx + .and_then(|ctx| ctx.ccs_sparse_cache.clone()) + .unwrap_or_else(|| Arc::new(SparseCache::build(s))), + ) + } else { + None + }; + let ccs_mat_digest = prover_ctx + .map(|ctx| ctx.ccs_mat_digest.clone()) + .unwrap_or_else(|| utils::digest_ccs_matrices_with_sparse_cache(s, ccs_sparse_cache.as_deref())); + if mode_uses_sparse_cache(&mode) && ccs_sparse_cache.is_none() { + return Err(PiCcsError::ProtocolError( + "missing SparseCache for optimized mode".into(), + )); + } + let k_dec = params.k_rho as usize; + let ring = ccs::RotRing::goldilocks(); + + if acc_init.len() != acc_wit_init.len() { + return Err(PiCcsError::InvalidInput(format!( + "acc_init.len()={} != acc_wit_init.len()={}", + acc_init.len(), + acc_wit_init.len() + ))); + } + + // Initialize accumulator + let mut accumulator = acc_init.to_vec(); + let mut accumulator_wit = acc_wit_init.to_vec(); + + let mut step_proofs = Vec::with_capacity(steps.len()); + let mut val_lane_wits: Vec> = Vec::new(); + let mut prev_twist_decoded: Option> = None; + let mut output_proof: Option = None; + + if ob.is_some() && steps.is_empty() { + return Err(PiCcsError::InvalidInput("output binding requires >= 1 step".into())); + } + + for (idx, step) in steps.iter().enumerate() { + let step_idx = step_idx_offset + .checked_add(idx) + .ok_or_else(|| PiCcsError::InvalidInput("step index overflow".into()))?; + let step_start = time_now(); + crate::memory_sidecar::memory::absorb_step_memory_witness(tr, step); + + let include_ob = ob.is_some() && (idx + 1 == steps.len()); + let mut wb_time_claim: Option = None; + let mut wp_time_claim: Option = None; + let mut decode_decode_fields_claim: Option = None; + let mut decode_decode_immediates_claim: Option = + None; + let mut width_bitness_claim: Option = None; + let mut width_quiescence_claim: Option = None; + let mut width_load_semantics_claim: Option = None; + let mut 
width_store_semantics_claim: Option = None; + let mut control_next_pc_linear_claim: Option = None; + let mut control_next_pc_control_claim: Option = + None; + let mut control_branch_semantics_claim: Option = + None; + let mut control_control_writeback_claim: Option = + None; + let mut ob_time_claim: Option = None; + let mut ob_r_prime: Option> = None; + + // Output binding is injected only on the final step, and must run before sampling Route-A `r_time`. + if include_ob { + let (cfg, final_memory_state) = + ob.ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but config missing".into()))?; + + if output_proof.is_some() { + return Err(PiCcsError::ProtocolError( + "output binding already attached (internal error)".into(), + )); + } + + if cfg.mem_idx >= step.mem_instances.len() { + return Err(PiCcsError::InvalidInput("output binding mem_idx out of range".into())); + } + let expected_k = 1usize + .checked_shl(cfg.num_bits as u32) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^num_bits overflow".into()))?; + if final_memory_state.len() != expected_k { + return Err(PiCcsError::InvalidInput(format!( + "output binding: final_memory_state.len()={} != 2^num_bits={}", + final_memory_state.len(), + expected_k + ))); + } + let mem_inst = &step.mem_instances[cfg.mem_idx].0; + if mem_inst.k != expected_k { + return Err(PiCcsError::InvalidInput(format!( + "output binding: cfg.num_bits implies k={}, but mem_inst.k={}", + expected_k, mem_inst.k + ))); + } + let ell_addr = mem_inst.twist_layout().lanes[0].ell_addr; + if ell_addr != cfg.num_bits { + return Err(PiCcsError::InvalidInput(format!( + "output binding: cfg.num_bits={}, but twist_layout.ell_addr={}", + cfg.num_bits, ell_addr + ))); + } + + tr.append_message(b"shard/output_binding_start", &(step_idx as u64).to_le_bytes()); + tr.append_u64s(b"output_binding/mem_idx", &[cfg.mem_idx as u64]); + tr.append_u64s(b"output_binding/num_bits", &[cfg.num_bits as u64]); + + let (output_sc, r_prime) = 
neo_memory::output_check::generate_output_sumcheck_proof_and_challenges( + tr, + cfg.num_bits, + cfg.program_io.clone(), + final_memory_state, + ) + .map_err(|e| PiCcsError::ProtocolError(format!("output sumcheck failed: {e:?}")))?; + + output_proof = Some(neo_memory::output_check::OutputBindingProof { output_sc }); + ob_r_prime = Some(r_prime); + } + + let (mcs_inst, mcs_wit) = &step.mcs; + + // k = accumulator.len() + 1 + let k = accumulator.len() + 1; + + // -------------------------------------------------------------------- + // Route A: Shared-challenge batched sum-check for time/row rounds. + // -------------------------------------------------------------------- + // + // 1) Bind CCS header + ME inputs + // 2) Sample CCS challenges (α, β, γ) and bind initial sum + // 3) Build CCS oracle + lazy Twist/Shout oracles + // 4) Run ONE batched sum-check for the first ell_n rounds (row/time) + // 5) Finish CCS alone for remaining ell_d Ajtai rounds + // 6) Emit CCS + memory ME claims at the shared r_time and fold via RLC/DEC + + utils::bind_header_and_instances_with_digest( + tr, + params, + &s, + core::slice::from_ref(mcs_inst), + dims, + &ccs_mat_digest, + )?; + utils::bind_me_inputs(tr, &accumulator)?; + let mut ch = utils::sample_challenges(tr, ell_d, ell)?; + ch.beta_m = utils::sample_beta_m(tr, ell_m)?; + let ccs_initial_sum = claimed_initial_sum_from_inputs(&s, &ch, &accumulator); + tr.append_fields(b"sumcheck/initial_sum", &ccs_initial_sum.as_coeffs()); + + // Route A memory checks use a separate transcript-derived cycle point `r_cycle` + // to form χ_{r_cycle}(t) weights inside their sum-check polynomials. + let r_cycle: Vec = + ts::sample_ext_point(tr, b"route_a/r_cycle", b"route_a/cycle/0", b"route_a/cycle/1", ell_n); + + // CCS oracle (engine-selected). + // + // Keep the optimized oracle concrete so we can build outputs from its Ajtai precompute. 
+ let mut ccs_oracle: CcsOracleDispatch<'_> = match mode.clone() { + FoldingMode::Optimized => { + let sparse = ccs_sparse_cache + .as_ref() + .ok_or_else(|| PiCcsError::ProtocolError("missing SparseCache for optimized mode".into()))?; + CcsOracleDispatch::Optimized( + neo_reductions::engines::optimized_engine::oracle::OptimizedOracle::new_with_sparse( + &s, + params, + core::slice::from_ref(mcs_wit), + &accumulator_wit, + ch.clone(), + ell_d, + ell_n, + d_sc, + accumulator.first().map(|mi| mi.r.as_slice()), + sparse.clone(), + ), + ) + } + #[cfg(feature = "paper-exact")] + FoldingMode::PaperExact => CcsOracleDispatch::PaperExact( + neo_reductions::engines::paper_exact_engine::oracle::PaperExactOracle::new( + &s, + params, + core::slice::from_ref(mcs_wit), + &accumulator_wit, + ch.clone(), + ell_d, + ell_n, + d_sc, + accumulator.first().map(|mi| mi.r.as_slice()), + ), + ), + #[cfg(feature = "paper-exact")] + FoldingMode::OptimizedWithCrosscheck(_) => { + let sparse = ccs_sparse_cache + .as_ref() + .ok_or_else(|| PiCcsError::ProtocolError("missing SparseCache for optimized mode".into()))?; + CcsOracleDispatch::Optimized( + neo_reductions::engines::optimized_engine::oracle::OptimizedOracle::new_with_sparse( + &s, + params, + core::slice::from_ref(mcs_wit), + &accumulator_wit, + ch.clone(), + ell_d, + ell_n, + d_sc, + accumulator.first().map(|mi| mi.r.as_slice()), + sparse.clone(), + ), + ) + } + }; + + let shout_pre = crate::memory_sidecar::memory::prove_shout_addr_pre_time( + tr, params, step, &cpu_bus, ell_n, &r_cycle, step_idx, + )?; + + let twist_pre = + crate::memory_sidecar::memory::prove_twist_addr_pre_time(tr, params, step, &cpu_bus, ell_n, &r_cycle)?; + let twist_read_claims: Vec = twist_pre.iter().map(|p| p.read_check_claim_sum).collect(); + let twist_write_claims: Vec = twist_pre.iter().map(|p| p.write_check_claim_sum).collect(); + let mut mem_oracles = crate::memory_sidecar::memory::build_route_a_memory_oracles( + params, step, ell_n, &r_cycle, 
&shout_pre, &twist_pre, + )?; + + let (wb_time_claim_built, wp_time_claim_built) = + crate::memory_sidecar::memory::build_route_a_wb_wp_time_claims(params, step, &r_cycle)?; + let wb_wp_required = crate::memory_sidecar::memory::wb_wp_required_for_step_witness(step); + if wb_wp_required && (wb_time_claim_built.is_none() || wp_time_claim_built.is_none()) { + return Err(PiCcsError::ProtocolError( + "WB/WP claims are required in RV32 trace mode but were not built".into(), + )); + } + if let Some((oracle, _claimed_sum)) = wb_time_claim_built { + wb_time_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"wb/booleanity", + }); + } + if let Some((oracle, _claimed_sum)) = wp_time_claim_built { + wp_time_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"wp/quiescence", + }); + } + let (decode_decode_fields_built, decode_decode_immediates_built) = + crate::memory_sidecar::memory::build_route_a_decode_time_claims(params, step, &r_cycle)?; + let decode_required = crate::memory_sidecar::memory::decode_stage_required_for_step_witness(step); + if decode_required && (decode_decode_fields_built.is_none() || decode_decode_immediates_built.is_none()) { + return Err(PiCcsError::ProtocolError( + "decode stage claims are required in RV32 trace mode but were not built".into(), + )); + } + if let Some((oracle, _claimed_sum)) = decode_decode_fields_built { + decode_decode_fields_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"decode/fields", + }); + } + if let Some((oracle, _claimed_sum)) = decode_decode_immediates_built { + decode_decode_immediates_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"decode/immediates", + }); + } + let ( + width_bitness_built, + width_quiescence_built, + _width_selector_linkage_built, + 
width_load_semantics_built, + width_store_semantics_built, + ) = crate::memory_sidecar::memory::build_route_a_width_time_claims(params, step, &r_cycle)?; + let width_required = crate::memory_sidecar::memory::width_stage_required_for_step_witness(step); + if width_required + && (width_bitness_built.is_none() + || width_quiescence_built.is_none() + || width_load_semantics_built.is_none() + || width_store_semantics_built.is_none()) + { + return Err(PiCcsError::ProtocolError( + "width stage claims are required in RV32 trace mode but were not built".into(), + )); + } + if let Some((oracle, _claimed_sum)) = width_bitness_built { + width_bitness_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"width/bitness", + }); + } + if let Some((oracle, _claimed_sum)) = width_quiescence_built { + width_quiescence_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"width/quiescence", + }); + } + if let Some((oracle, _claimed_sum)) = width_load_semantics_built { + width_load_semantics_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"width/load_semantics", + }); + } + if let Some((oracle, _claimed_sum)) = width_store_semantics_built { + width_store_semantics_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"width/store_semantics", + }); + } + let ( + control_next_pc_linear_built, + control_next_pc_control_built, + control_branch_semantics_built, + control_control_writeback_built, + ) = crate::memory_sidecar::memory::build_route_a_control_time_claims(params, step, &r_cycle)?; + let control_required = crate::memory_sidecar::memory::control_stage_required_for_step_witness(step); + if control_required + && (control_next_pc_linear_built.is_none() + || control_next_pc_control_built.is_none() + || 
control_branch_semantics_built.is_none() + || control_control_writeback_built.is_none()) + { + return Err(PiCcsError::ProtocolError( + "control stage claims are required in RV32 trace mode but were not built".into(), + )); + } + if let Some((oracle, _claimed_sum)) = control_next_pc_linear_built { + control_next_pc_linear_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"control/next_pc_linear", + }); + } + if let Some((oracle, _claimed_sum)) = control_next_pc_control_built { + control_next_pc_control_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"control/next_pc_control", + }); + } + if let Some((oracle, _claimed_sum)) = control_branch_semantics_built { + control_branch_semantics_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"control/branch_semantics", + }); + } + if let Some((oracle, _claimed_sum)) = control_control_writeback_built { + control_control_writeback_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle, + claimed_sum: K::ZERO, + label: b"control/writeback", + }); + } + + if include_ob { + let (cfg, _final_memory_state) = + ob.ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but config missing".into()))?; + let r_prime = ob_r_prime + .as_ref() + .ok_or_else(|| PiCcsError::ProtocolError("output binding r_prime missing".into()))?; + let pre = twist_pre + .get(cfg.mem_idx) + .ok_or_else(|| PiCcsError::ProtocolError("output binding mem_idx out of range for twist_pre".into()))?; + + if pre.decoded.lanes.is_empty() { + return Err(PiCcsError::ProtocolError( + "output binding: Twist decoded lanes empty".into(), + )); + } + + let mut oracles: Vec> = Vec::with_capacity(pre.decoded.lanes.len()); + let mut claimed_sum = K::ZERO; + for lane in pre.decoded.lanes.iter() { + let (oracle, claim) = 
neo_memory::twist_oracle::TwistTotalIncOracleSparseTime::new( + lane.wa_bits.clone(), + lane.has_write.clone(), + lane.inc_at_write_addr.clone(), + r_prime, + ); + oracles.push(Box::new(oracle)); + claimed_sum += claim; + } + let oracle = crate::memory_sidecar::memory::SumRoundOracle::new(oracles); + + ob_time_claim = Some(crate::memory_sidecar::route_a_time::ExtraBatchedTimeClaim { + oracle: Box::new(oracle), + claimed_sum, + label: crate::output_binding::OB_INC_TOTAL_LABEL, + }); + } + + let crate::memory_sidecar::route_a_time::RouteABatchedTimeProverOutput { + r_time, + per_claim_results, + proof: batched_time, + } = crate::memory_sidecar::route_a_time::prove_route_a_batched_time( + tr, + step_idx, + ell_n, + d_sc, + ccs_initial_sum, + &mut ccs_oracle, + &mut mem_oracles, + step, + twist_read_claims, + twist_write_claims, + wb_time_claim, + wp_time_claim, + decode_decode_fields_claim, + decode_decode_immediates_claim, + width_bitness_claim, + width_quiescence_claim, + None, + width_load_semantics_claim, + width_store_semantics_claim, + control_next_pc_linear_claim, + control_next_pc_control_claim, + control_branch_semantics_claim, + control_control_writeback_claim, + ob_time_claim, + )?; + + // Finish CCS Ajtai rounds alone, continuing from the CCS oracle state after ell_n folds. 
+ let ccs_time_rounds = per_claim_results + .first() + .map(|r| r.round_polys.clone()) + .unwrap_or_default(); + let mut sumcheck_rounds = ccs_time_rounds; + let mut sumcheck_chals = r_time.clone(); + let ajtai_initial_sum = per_claim_results + .first() + .map(|r| r.final_value) + .unwrap_or(ccs_initial_sum); + + let mut ccs_ajtai = RoundOraclePrefix::new(&mut ccs_oracle, ell_d); + let (ajtai_rounds, ajtai_chals) = + run_sumcheck_prover_ds(tr, b"ccs/ajtai", step_idx, &mut ccs_ajtai, ajtai_initial_sum)?; + let mut running_sum = ajtai_initial_sum; + for (round_poly, &r_i) in ajtai_rounds.iter().zip(ajtai_chals.iter()) { + running_sum = poly_eval_k(round_poly, r_i); + } + sumcheck_rounds.extend_from_slice(&ajtai_rounds); + sumcheck_chals.extend_from_slice(&ajtai_chals); + + // -------------------------------------------------------------------- + // NC-only sumcheck (digit-range / norm-check) over {0,1}^{ell_m + ell_d}. + // -------------------------------------------------------------------- + let mut ccs_nc_oracle = neo_reductions::engines::optimized_engine::oracle::NcOracle::new( + &s, + params, + core::slice::from_ref(mcs_wit), + &accumulator_wit, + ch.clone(), + ell_d, + ell_m, + d_sc, + ); + let (sumcheck_rounds_nc, sumcheck_chals_nc) = + run_sumcheck_prover_ds(tr, b"ccs/nc", step_idx, &mut ccs_nc_oracle, K::ZERO)?; + let mut running_sum_nc = K::ZERO; + for (round_poly, &r_i) in sumcheck_rounds_nc.iter().zip(sumcheck_chals_nc.iter()) { + running_sum_nc = poly_eval_k(round_poly, r_i); + } + let (s_col, _alpha_prime_nc) = sumcheck_chals_nc.split_at(ell_m); + + // Build CCS ME outputs at r_time. 
+ let fold_digest = tr.digest32(); + let mut ccs_out = match &mut ccs_oracle { + CcsOracleDispatch::Optimized(oracle) => oracle.build_me_outputs_from_ajtai_precomp( + core::slice::from_ref(mcs_inst), + &accumulator, + s_col, + fold_digest, + l, + ), + #[cfg(feature = "paper-exact")] + CcsOracleDispatch::PaperExact(_) => build_me_outputs_paper_exact( + &s, + params, + core::slice::from_ref(mcs_inst), + core::slice::from_ref(mcs_wit), + &accumulator, + &accumulator_wit, + &r_time, + s_col, + ell_d, + fold_digest, + l, + ), + }; + + // CCS oracle borrows accumulator_wit; drop before updating accumulator_wit at the end. + drop(ccs_oracle); + + let mut trace_linkage_t_len: Option = None; + + // Shared CPU bus: append "implicit openings" for all bus columns without materializing + // bus copyout matrices into the CCS. + if cpu_bus.bus_cols > 0 { + let core_t = s.t(); + if ccs_out.len() != 1 + accumulator_wit.len() { + return Err(PiCcsError::ProtocolError(format!( + "CCS output count mismatch for bus openings (ccs_out.len()={}, expected {})", + ccs_out.len(), + 1 + accumulator_wit.len() + ))); + } + + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( + params, + &cpu_bus, + core_t, + &mcs_wit.Z, + &mut ccs_out[0], + )?; + for (out, Z) in ccs_out.iter_mut().skip(1).zip(accumulator_wit.iter()) { + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance(params, &cpu_bus, core_t, Z, out)?; + } + } + + // For RV32 trace wiring CCS, append time-combined openings for trace columns needed to + // link Twist/Shout sidecars at r_time. In shared-bus mode this is appended after bus openings. 
+ if (!step.mem_instances.is_empty() || !step.lut_instances.is_empty()) && mcs_inst.m_in == 5 { + // Infer that the CPU witness is the RV32 trace column-major layout: + // z = [x (m_in) | trace_cols * t_len] + let m_in = mcs_inst.m_in; + let t_len = step + .mem_instances + .first() + .map(|(inst, _wit)| inst.steps) + .or_else(|| { + // Shout event-table instances may have `steps != t_len`; prefer a non-event-table + // instance if present, otherwise fall back to inferring from the trace layout. + step.lut_instances + .iter() + .find(|(inst, _wit)| { + !matches!(inst.table_spec, Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. })) + }) + .map(|(inst, _wit)| inst.steps) + }) + .or_else(|| { + // Trace CCS layout inference: z = [x (m_in) | trace_cols * t_len] + let trace = Rv32TraceLayout::new(); + let w = s.m.checked_sub(m_in)?; + if trace.cols == 0 || w % trace.cols != 0 { + return None; + } + Some(w / trace.cols) + }) + .ok_or_else(|| PiCcsError::InvalidInput("missing mem/lut instances".into()))?; + if t_len == 0 { + return Err(PiCcsError::InvalidInput("trace linkage requires steps>=1".into())); + } + for (i, (inst, _wit)) in step.mem_instances.iter().enumerate() { + if inst.steps != t_len { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage requires stable steps across mem instances (mem_idx={i} has steps={}, expected {t_len})", + inst.steps + ))); + } + } + + let trace = Rv32TraceLayout::new(); + let trace_len = trace + .cols + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace cols * t_len overflow".into()))?; + let expected_m = m_in + .checked_add(trace_len) + .ok_or_else(|| PiCcsError::InvalidInput("m_in + trace_len overflow".into()))?; + if s.m < expected_m { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage expects m >= m_in + trace.cols*t_len (m={}; min_m={expected_m} for t_len={t_len}, trace_cols={})", + s.m, trace.cols + ))); + } + + let trace_cols_to_open_dense: Vec = vec![ + trace.active, + trace.cycle, + 
trace.pc_before, + trace.instr_word, + trace.rs1_addr, + trace.rs1_val, + trace.rs2_addr, + trace.rs2_val, + trace.rd_addr, + trace.rd_val, + trace.ram_addr, + trace.ram_rv, + trace.ram_wv, + ]; + let trace_cols_to_open_shout: Vec = vec![ + trace.shout_has_lookup, + trace.shout_val, + trace.shout_lhs, + trace.shout_rhs, + ]; + let trace_cols_to_open_all: Vec = trace_cols_to_open_dense + .iter() + .chain(trace_cols_to_open_shout.iter()) + .copied() + .collect(); + let core_t = s.t(); + let trace_open_base = core_t + cpu_bus.bus_cols; + let col_base = m_in; // trace_base in the RV32 trace layout + + // Event-table style micro-optimization: Shout trace columns are constrained to be 0 + // whenever `shout_has_lookup == 0`, so we can compute their openings by summing only + // over the active lookup rows. + let active_shout_js: Vec = { + let d = neo_math::D; + let mut out: Vec = Vec::new(); + let col_offset = trace + .shout_has_lookup + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace col_id * t_len overflow".into()))?; + for j in 0..t_len { + let z_idx = col_base + .checked_add(col_offset) + .and_then(|x| x.checked_add(j)) + .ok_or_else(|| PiCcsError::InvalidInput("trace z index overflow".into()))?; + if z_idx >= mcs_wit.Z.cols() { + return Err(PiCcsError::InvalidInput(format!( + "trace openings: z_idx out of range (z_idx={z_idx}, m={})", + mcs_wit.Z.cols() + ))); + } + + let mut any = false; + for rho in 0..d { + if mcs_wit.Z[(rho, z_idx)] != F::ZERO { + any = true; + break; + } + } + if any { + out.push(j); + } + } + out + }; + + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + col_base, + &trace_cols_to_open_dense, + trace_open_base, + &mcs_wit.Z, + &mut ccs_out[0], + )?; + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance_at_js( + params, + m_in, + t_len, + col_base, + &trace_cols_to_open_shout, + trace_open_base + trace_cols_to_open_dense.len(), + &mcs_wit.Z, 
+ &mut ccs_out[0], + &active_shout_js, + )?; + for (out, Z) in ccs_out.iter_mut().skip(1).zip(accumulator_wit.iter()) { + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + col_base, + &trace_cols_to_open_all, + trace_open_base, + Z, + out, + )?; + } + trace_linkage_t_len = Some(t_len); + } + + if ccs_out.len() != k { + return Err(PiCcsError::ProtocolError(format!( + "Π_CCS returned {} outputs; expected k={k}", + ccs_out.len() + ))); + } + + let mut ccs_proof = crate::PiCcsProof::new(sumcheck_rounds, Some(ccs_initial_sum)); + ccs_proof.variant = crate::optimized_engine::PiCcsProofVariant::SplitNcV1; + ccs_proof.sumcheck_challenges = sumcheck_chals; + ccs_proof.sumcheck_rounds_nc = sumcheck_rounds_nc; + ccs_proof.sc_initial_sum_nc = Some(K::ZERO); + ccs_proof.sumcheck_challenges_nc = sumcheck_chals_nc; + ccs_proof.challenges_public = ch; + ccs_proof.sumcheck_final = running_sum; + ccs_proof.sumcheck_final_nc = running_sum_nc; + ccs_proof.header_digest = fold_digest.to_vec(); + + #[cfg(feature = "paper-exact")] + if let FoldingMode::OptimizedWithCrosscheck(cfg) = &mode { + crosscheck_route_a_ccs_step( + cfg, + step_idx, + params, + &s, + &cpu_bus, + mcs_inst, + mcs_wit, + &accumulator, + &accumulator_wit, + &ccs_out, + &ccs_proof, + ell_d, + ell_n, + ell_m, + d_sc, + fold_digest, + l, + )?; + } + + // Witnesses for CCS outputs: [Z_mcs, Z_seed...] (borrow; avoid multi-GB clones) + let mut outs_Z: Vec<&Mat> = Vec::with_capacity(k); + outs_Z.push(&mcs_wit.Z); + outs_Z.extend(accumulator_wit.iter()); + + // Memory sidecar: emit ME claims at the shared r_time (no fixed-challenge sumcheck). 
+ let prev_step = (idx > 0).then(|| &steps[idx - 1]); + let prev_twist_decoded_ref = prev_twist_decoded.as_deref(); + let mut mem_proof = crate::memory_sidecar::memory::finalize_route_a_memory_prover( + tr, + params, + &cpu_bus, + &s, + step, + prev_step, + prev_twist_decoded_ref, + &mut mem_oracles, + &shout_pre.addr_pre, + &twist_pre, + &r_time, + mcs_inst.m_in, + step_idx, + )?; + prev_twist_decoded = Some(twist_pre.into_iter().map(|p| p.decoded).collect()); + + // Normalize ME claim shapes for per-claim folding lanes. + for me in mem_proof.val_me_claims.iter_mut() { + let t = me.y.len(); + normalize_me_claims(core::slice::from_mut(me), ell_n, ell_d, t)?; + } + for me in mem_proof.wb_me_claims.iter_mut() { + let t = me.y.len(); + normalize_me_claims(core::slice::from_mut(me), ell_n, ell_d, t)?; + } + for me in mem_proof.wp_me_claims.iter_mut() { + let t = me.y.len(); + normalize_me_claims(core::slice::from_mut(me), ell_n, ell_d, t)?; + } + + validate_me_batch_invariants(&ccs_out, "prove step ccs outputs")?; + + let want_main_wits = collect_val_lane_wits || idx + 1 < steps.len(); + let (main_fold, Z_split) = prove_rlc_dec_lane( + &mode, + RlcLane::Main, + tr, + params, + &s, + ccs_sparse_cache.as_deref(), + Some(&cpu_bus), + &ring, + ell_d, + k_dec, + step_idx, + trace_linkage_t_len, + &ccs_out, + &outs_Z, + want_main_wits, + l, + mixers, + )?; + let RlcDecProof { + rlc_rhos: rhos, + rlc_parent: parent_pub, + dec_children: children, + } = main_fold; + + let has_prev = prev_step.is_some(); + + // -------------------------------------------------------------------- + // Phase 2: Second folding lane for Twist val-eval ME claims at r_val. 
+ // -------------------------------------------------------------------- + let mut val_fold: Vec = Vec::new(); + if !mem_proof.val_me_claims.is_empty() { + tr.append_message(b"fold/val_lane_start", &(step_idx as u64).to_le_bytes()); + let expected = 1usize + usize::from(has_prev); + if mem_proof.val_me_claims.len() != expected { + return Err(PiCcsError::ProtocolError(format!( + "Twist(val) claim count mismatch (have {}, expected {})", + mem_proof.val_me_claims.len(), + expected + ))); + } + let can_reuse_main_lane_dec = + ccs_out.len() == 1 && outs_Z.len() == 1 && !Z_split.is_empty() && children.len() == Z_split.len(); + let shared_val_lane_child_cs: Option> = if can_reuse_main_lane_dec { + Some(children.iter().map(|child| child.c.clone()).collect()) + } else { + None + }; + + for (claim_idx, me) in mem_proof.val_me_claims.iter().enumerate() { + let (wit, ctx) = match claim_idx { + 0 => (&mcs_wit.Z, "cpu"), + 1 => { + let prev = prev_step + .ok_or_else(|| PiCcsError::ProtocolError("missing prev_step for r_val claim".into()))?; + (&prev.mcs.1.Z, "cpu_prev") + } + _ => { + return Err(PiCcsError::ProtocolError( + "unexpected extra r_val ME claim in shared-bus mode".into(), + )); + } + }; + tr.append_message(b"fold/val_lane_claim_idx", &(claim_idx as u64).to_le_bytes()); + tr.append_message(b"fold/val_lane_claim_ctx", ctx.as_bytes()); + + // Reuse main-lane split/commit artifacts for the current-step shared-bus + // val lane so we don't pay an extra full split+commit. 
+ if claim_idx == 0 { + if let Some(child_cs) = shared_val_lane_child_cs.as_ref() { + bind_rlc_inputs(tr, RlcLane::Val, step_idx, core::slice::from_ref(me))?; + let rlc_rhos = ccs::sample_rot_rhos_n(tr, params, &ring, 1)?; + let mut rlc_parent = ccs::rlc_public( + &s, + params, + &rlc_rhos, + core::slice::from_ref(me), + mixers.mix_rhos_commits, + ell_d, + )?; + let (mut dec_children, ok_y, ok_x, ok_c) = ccs::dec_children_with_commit_cached( + mode.clone(), + &s, + params, + &rlc_parent, + &Z_split, + ell_d, + child_cs, + mixers.combine_b_pows, + ccs_sparse_cache.as_deref(), + ); + if !(ok_y && ok_x && ok_c) { + return Err(PiCcsError::ProtocolError(format!( + "DEC(val) public check failed at step {} (y={}, X={}, c={})", + step_idx, ok_y, ok_x, ok_c + ))); + } + if cpu_bus.bus_cols > 0 { + let core_t = s.t(); + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( + params, + &cpu_bus, + core_t, + wit, + &mut rlc_parent, + )?; + for (child, zi) in dec_children.iter_mut().zip(Z_split.iter()) { + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( + params, &cpu_bus, core_t, zi, child, + )?; + } + } + if collect_val_lane_wits { + val_lane_wits.extend(Z_split.iter().cloned()); + } + val_fold.push(RlcDecProof { + rlc_rhos, + rlc_parent, + dec_children, + }); + continue; + } + } + + let (proof, mut Z_split_val) = prove_rlc_dec_lane( + &mode, + RlcLane::Val, + tr, + params, + &s, + ccs_sparse_cache.as_deref(), + Some(&cpu_bus), + &ring, + ell_d, + k_dec, + step_idx, + None, + core::slice::from_ref(me), + core::slice::from_ref(&wit), + collect_val_lane_wits, + l, + mixers, + )?; + if collect_val_lane_wits { + val_lane_wits.extend(Z_split_val.drain(..)); + } + val_fold.push(proof); + } + } + + // Additional WB/WP folding lane(s): CPU ME openings used by wb/booleanity and + // wp/quiescence stages. These lanes share the same witness matrix (`mcs_wit.Z`), + // so precompute DEC digit witnesses + child commitments once per step. 
+ let mut wb_wp_dec_wits: Option>> = None; + let mut wb_wp_child_cs: Option> = None; + if !mem_proof.wb_me_claims.is_empty() || !mem_proof.wp_me_claims.is_empty() { + let (dec_wits, digit_nonzero) = ccs::split_b_matrix_k_with_nonzero_flags(&mcs_wit.Z, k_dec, params.b)?; + let zero_c = Cmt::zeros(mcs_inst.c.d, mcs_inst.c.kappa); + let child_cs: Vec = { + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + const PAR_CHILD_COMMIT_THRESHOLD: usize = 32; + let use_parallel = dec_wits.len() >= PAR_CHILD_COMMIT_THRESHOLD && rayon::current_num_threads() > 1; + if use_parallel { + dec_wits + .par_iter() + .enumerate() + .map(|(idx, Zi)| { + if digit_nonzero[idx] { + l.commit(Zi) + } else { + zero_c.clone() + } + }) + .collect() + } else { + dec_wits + .iter() + .enumerate() + .map(|(idx, Zi)| { + if digit_nonzero[idx] { + l.commit(Zi) + } else { + zero_c.clone() + } + }) + .collect() + } + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + dec_wits + .iter() + .enumerate() + .map(|(idx, Zi)| { + if digit_nonzero[idx] { + l.commit(Zi) + } else { + zero_c.clone() + } + }) + .collect() + } + }; + wb_wp_dec_wits = Some(dec_wits); + wb_wp_child_cs = Some(child_cs); + } + + // Additional WB folding lane(s): CPU ME openings used by wb/booleanity stage. 
+ let mut wb_fold: Vec = Vec::new(); + if !mem_proof.wb_me_claims.is_empty() { + let trace = Rv32TraceLayout::new(); + let t_len = crate::memory_sidecar::memory::infer_rv32_trace_t_len_for_wb_wp(step, &trace)?; + let wb_cols = crate::memory_sidecar::memory::rv32_trace_wb_columns(&trace); + let core_t = s.t(); + let m_in = mcs_inst.m_in; + let dec_wits = wb_wp_dec_wits + .as_ref() + .ok_or_else(|| PiCcsError::ProtocolError("WB fold missing shared DEC witnesses".into()))?; + let child_cs = wb_wp_child_cs + .as_ref() + .ok_or_else(|| PiCcsError::ProtocolError("WB fold missing shared DEC commitments".into()))?; + tr.append_message(b"fold/wb_lane_start", &(step_idx as u64).to_le_bytes()); + for (claim_idx, me) in mem_proof.wb_me_claims.iter().enumerate() { + tr.append_message(b"fold/wb_lane_claim_idx", &(claim_idx as u64).to_le_bytes()); + bind_rlc_inputs(tr, RlcLane::Val, step_idx, core::slice::from_ref(me))?; + let rlc_rhos = ccs::sample_rot_rhos_n(tr, params, &ring, 1)?; + let rlc_parent = ccs::rlc_public( + &s, + params, + &rlc_rhos, + core::slice::from_ref(me), + mixers.mix_rhos_commits, + ell_d, + )?; + let (mut dec_children, ok_y, ok_x, ok_c) = ccs::dec_children_with_commit_cached( + mode.clone(), + &s, + params, + &rlc_parent, + dec_wits, + ell_d, + child_cs, + mixers.combine_b_pows, + ccs_sparse_cache.as_deref(), + ); + if !(ok_y && ok_x && ok_c) { + return Err(PiCcsError::ProtocolError(format!( + "DEC(val) public check failed at step {} (y={}, X={}, c={})", + step_idx, ok_y, ok_x, ok_c + ))); + } + if dec_children.len() != dec_wits.len() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: WB fold requires materialized DEC witnesses (children={}, wits={})", + step_idx, + dec_children.len(), + dec_wits.len() + ))); + } + for (child, zi) in dec_children.iter_mut().zip(dec_wits.iter()) { + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, m_in, t_len, m_in, &wb_cols, core_t, zi, child, + )?; + } + if 
collect_val_lane_wits { + val_lane_wits.extend(dec_wits.iter().cloned()); + } + wb_fold.push(RlcDecProof { + rlc_rhos, + rlc_parent, + dec_children, + }); + } + } + + // Additional WP folding lane(s): CPU ME openings used by wp/quiescence stage. + let mut wp_fold: Vec = Vec::new(); + if !mem_proof.wp_me_claims.is_empty() { + let trace = Rv32TraceLayout::new(); + let t_len = crate::memory_sidecar::memory::infer_rv32_trace_t_len_for_wb_wp(step, &trace)?; + let mut wp_open_cols = crate::memory_sidecar::memory::rv32_trace_wp_opening_columns(&trace); + if control_required { + wp_open_cols.extend(crate::memory_sidecar::memory::rv32_trace_control_extra_opening_columns( + &trace, + )); + } + if decode_required { + let decode_layout = Rv32DecodeSidecarLayout::new(); + let (_decode_open_cols, decode_lut_indices) = + crate::memory_sidecar::memory::resolve_shared_decode_lookup_lut_indices(step, &decode_layout)?; + let bus = crate::memory_sidecar::memory::build_bus_layout_for_step_witness(step, t_len)?; + if bus.shout_cols.len() != step.lut_instances.len() { + return Err(PiCcsError::ProtocolError( + "W2(shared): bus layout shout lane count drift in WP fold".into(), + )); + } + let bus_base_delta = bus + .bus_base + .checked_sub(mcs_inst.m_in) + .ok_or_else(|| PiCcsError::ProtocolError("W2(shared): bus_base underflow in WP fold".into()))?; + if bus_base_delta % t_len != 0 { + return Err(PiCcsError::ProtocolError(format!( + "W2(shared): bus_base alignment mismatch in WP fold (bus_base_delta={}, t_len={t_len})", + bus_base_delta + ))); + } + let bus_col_offset = bus_base_delta / t_len; + for &lut_idx in decode_lut_indices.iter() { + let inst_cols = bus.shout_cols.get(lut_idx).ok_or_else(|| { + PiCcsError::ProtocolError( + "W2(shared): missing shout cols for decode lookup table in WP fold".into(), + ) + })?; + let lane0 = inst_cols.lanes.get(0).ok_or_else(|| { + PiCcsError::ProtocolError( + "W2(shared): expected one shout lane for decode lookup table in WP fold".into(), + ) + })?; 
+ wp_open_cols.push(bus_col_offset + lane0.primary_val()); + } + } + if width_required { + wp_open_cols.extend(crate::memory_sidecar::memory::width_lookup_bus_val_cols_witness( + step, t_len, + )?); + } + let core_t = s.t(); + let m_in = mcs_inst.m_in; + let dec_wits = wb_wp_dec_wits + .as_ref() + .ok_or_else(|| PiCcsError::ProtocolError("WP fold missing shared DEC witnesses".into()))?; + let child_cs = wb_wp_child_cs + .as_ref() + .ok_or_else(|| PiCcsError::ProtocolError("WP fold missing shared DEC commitments".into()))?; + tr.append_message(b"fold/wp_lane_start", &(step_idx as u64).to_le_bytes()); + for (claim_idx, me) in mem_proof.wp_me_claims.iter().enumerate() { + tr.append_message(b"fold/wp_lane_claim_idx", &(claim_idx as u64).to_le_bytes()); + bind_rlc_inputs(tr, RlcLane::Val, step_idx, core::slice::from_ref(me))?; + let rlc_rhos = ccs::sample_rot_rhos_n(tr, params, &ring, 1)?; + let rlc_parent = ccs::rlc_public( + &s, + params, + &rlc_rhos, + core::slice::from_ref(me), + mixers.mix_rhos_commits, + ell_d, + )?; + let (mut dec_children, ok_y, ok_x, ok_c) = ccs::dec_children_with_commit_cached( + mode.clone(), + &s, + params, + &rlc_parent, + dec_wits, + ell_d, + child_cs, + mixers.combine_b_pows, + ccs_sparse_cache.as_deref(), + ); + if !(ok_y && ok_x && ok_c) { + return Err(PiCcsError::ProtocolError(format!( + "DEC(val) public check failed at step {} (y={}, X={}, c={})", + step_idx, ok_y, ok_x, ok_c + ))); + } + if dec_children.len() != dec_wits.len() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: WP fold requires materialized DEC witnesses (children={}, wits={})", + step_idx, + dec_children.len(), + dec_wits.len() + ))); + } + for (child, zi) in dec_children.iter_mut().zip(dec_wits.iter()) { + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + m_in, + &wp_open_cols, + core_t, + zi, + child, + )?; + } + if collect_val_lane_wits { + val_lane_wits.extend(dec_wits.iter().cloned()); + } + 
wp_fold.push(RlcDecProof { + rlc_rhos, + rlc_parent, + dec_children, + }); + } + } + + accumulator = children.clone(); + accumulator_wit = if want_main_wits { Z_split } else { Vec::new() }; + + step_proofs.push(StepProof { + fold: FoldStep { + ccs_out, + ccs_proof, + rlc_rhos: rhos, + rlc_parent: parent_pub, + dec_children: children, + }, + mem: mem_proof, + batched_time, + val_fold, + wb_fold, + wp_fold, + }); + + tr.append_message(b"fold/step_done", &(step_idx as u64).to_le_bytes()); + if let Some(out) = step_prove_ms_out.as_deref_mut() { + out.push(elapsed_ms(step_start)); + } + } + + Ok(( + ShardProof { + steps: step_proofs, + output_proof, + }, + accumulator_wit, + val_lane_wits, + )) +} diff --git a/crates/neo-fold/src/shard/rlc_dec.rs b/crates/neo-fold/src/shard/rlc_dec.rs new file mode 100644 index 00000000..49f394c3 --- /dev/null +++ b/crates/neo-fold/src/shard/rlc_dec.rs @@ -0,0 +1,948 @@ +use super::*; + +pub(crate) fn prove_rlc_dec_lane( + mode: &FoldingMode, + lane: RlcLane, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s: &CcsStructure, + ccs_sparse_cache: Option<&SparseCache>, + cpu_bus: Option<&neo_memory::cpu::BusLayout>, + ring: &ccs::RotRing, + ell_d: usize, + k_dec: usize, + step_idx: usize, + trace_linkage_t_len: Option, + me_inputs: &[MeInstance], + wit_inputs: &[&Mat], + want_witnesses: bool, + l: &L, + mixers: CommitMixers, +) -> Result<(RlcDecProof, Vec>), PiCcsError> +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + if me_inputs.is_empty() { + let prefix = match lane { + RlcLane::Main => "", + RlcLane::Val => "val-lane ", + }; + return Err(PiCcsError::InvalidInput(format!( + "step {}: {prefix}RLC input batch is empty", + step_idx + ))); + } + if wit_inputs.len() != me_inputs.len() { + let prefix = match lane { + RlcLane::Main => "", + RlcLane::Val => "val-lane ", + }; + return Err(PiCcsError::InvalidInput(format!( + "step {}: {prefix}RLC 
witness count mismatch (me_inputs.len()={}, wit_inputs.len()={})", + step_idx, + me_inputs.len(), + wit_inputs.len() + ))); + } + + bind_rlc_inputs(tr, lane, step_idx, me_inputs)?; + let rlc_rhos = ccs::sample_rot_rhos_n(tr, params, ring, me_inputs.len())?; + let (mut rlc_parent, Z_mix) = if me_inputs.len() == 1 { + if rlc_rhos.len() != 1 { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_RLC(k=1): |rhos| must equal |inputs|", + step_idx + ))); + } + let inp = &me_inputs[0]; + + // Match `neo_reductions::api::rlc_with_commit` semantics for k=1 without cloning Z. + let inputs_c = vec![inp.c.clone()]; + let c = (mixers.mix_rhos_commits)(&rlc_rhos, &inputs_c); + + let t = inp.y.len(); + if t < s.t() { + return Err(PiCcsError::InvalidInput(format!( + "step {}: Π_RLC(k=1): ME y.len() must be >= s.t() (got {}, s.t()={})", + step_idx, + t, + s.t() + ))); + } + for (j, row) in inp.y.iter().enumerate() { + if row.len() < D { + return Err(PiCcsError::InvalidInput(format!( + "step {}: Π_RLC(k=1): ME y[{}].len()={} must be >= D={}", + step_idx, + j, + row.len(), + D + ))); + } + } + verify_me_y_scalars_canonical(inp, params.b, step_idx, "Π_RLC(k=1)")?; + + let out = MeInstance:: { + c_step_coords: vec![], + u_offset: 0, + u_len: 0, + c, + X: inp.X.clone(), + r: inp.r.clone(), + s_col: inp.s_col.clone(), + y: inp.y.clone(), + y_scalars: inp.y_scalars.clone(), + y_zcol: inp.y_zcol.clone(), + m_in: inp.m_in, + fold_digest: inp.fold_digest, + }; + + (out, Cow::Borrowed(wit_inputs[0])) + } else { + let (out, Z_mix) = { + #[cfg(feature = "paper-exact")] + { + if matches!(mode, FoldingMode::PaperExact) { + // Keep paper-exact dispatch through the public API. + let wit_owned: Vec> = wit_inputs.iter().map(|m| (*m).clone()).collect(); + ccs::rlc_with_commit( + mode.clone(), + s, + params, + &rlc_rhos, + me_inputs, + &wit_owned, + ell_d, + mixers.mix_rhos_commits, + )? 
+ } else { + neo_reductions::optimized_engine::rlc_reduction_optimized_with_commit_mix( + s, + params, + &rlc_rhos, + me_inputs, + wit_inputs, + ell_d, + mixers.mix_rhos_commits, + ) + } + } + #[cfg(not(feature = "paper-exact"))] + { + neo_reductions::optimized_engine::rlc_reduction_optimized_with_commit_mix( + s, + params, + &rlc_rhos, + me_inputs, + wit_inputs, + ell_d, + mixers.mix_rhos_commits, + ) + } + }; + (out, Cow::Owned(Z_mix)) + }; + + let Z_mix = Z_mix.as_ref(); + + let inputs_have_extra_y = me_inputs.iter().any(|me| me.y.len() > s.t()); + let can_stream_dec = !want_witnesses + && has_global_pp_for_dims(D, s.m) + && !cpu_bus.map(|b| b.bus_cols > 0).unwrap_or(false) + && !inputs_have_extra_y; + + let materialize_dec = || -> Result<(Vec>, bool, bool, bool, Vec>), PiCcsError> { + // Standard DEC: materialize digit matrices (needed when carrying witnesses forward). + let (Z_split, digit_nonzero) = ccs::split_b_matrix_k_with_nonzero_flags(Z_mix, k_dec, params.b)?; + let zero_c = Cmt::zeros(rlc_parent.c.d, rlc_parent.c.kappa); + let child_cs: Vec = { + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + const PAR_CHILD_COMMIT_THRESHOLD: usize = 32; + let use_parallel = Z_split.len() >= PAR_CHILD_COMMIT_THRESHOLD && rayon::current_num_threads() > 1; + if use_parallel { + Z_split + .par_iter() + .enumerate() + .map(|(idx, Zi)| { + if digit_nonzero[idx] { + l.commit(Zi) + } else { + zero_c.clone() + } + }) + .collect() + } else { + Z_split + .iter() + .enumerate() + .map(|(idx, Zi)| { + if digit_nonzero[idx] { + l.commit(Zi) + } else { + zero_c.clone() + } + }) + .collect() + } + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + Z_split + .iter() + .enumerate() + .map(|(idx, Zi)| { + if digit_nonzero[idx] { + l.commit(Zi) + } else { + zero_c.clone() + } + }) + .collect() + } + }; + let (dec_children, ok_y, ok_X, ok_c) = ccs::dec_children_with_commit_cached( + mode.clone(), + s, + params, + &rlc_parent, + &Z_split, 
+ ell_d, + &child_cs, + mixers.combine_b_pows, + ccs_sparse_cache, + ); + Ok((dec_children, ok_y, ok_X, ok_c, Z_split)) + }; + + let (mut dec_children, ok_y, ok_X, ok_c, maybe_wits) = if can_stream_dec { + // Memory-optimized DEC: compute children + commitments without materializing Z_split. + // If public consistency checks fail (e.g. global PP mismatch vs local committer), + // fall back to the materialized path for correctness. + let (children, _child_cs, ok_y, ok_X, ok_c) = dec_stream_no_witness( + params, + s, + &rlc_parent, + Z_mix, + ell_d, + k_dec, + mixers.combine_b_pows, + ccs_sparse_cache, + )?; + if ok_y && ok_X && ok_c { + (children, ok_y, ok_X, ok_c, Vec::new()) + } else { + materialize_dec()? + } + } else { + materialize_dec()? + }; + if !(ok_y && ok_X && ok_c) { + let lane_label = match lane { + RlcLane::Main => "DEC", + RlcLane::Val => "DEC(val)", + }; + return Err(PiCcsError::ProtocolError(format!( + "{} public check failed at step {} (y={}, X={}, c={})", + lane_label, step_idx, ok_y, ok_X, ok_c + ))); + } + + // Shared CPU bus: carry the implicit bus openings through Π_RLC/Π_DEC so they remain + // part of the folded instance (and are checked by public DEC verification). + if let Some(bus) = cpu_bus { + if bus.bus_cols > 0 { + let core_t = s.t(); + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( + params, + bus, + core_t, + Z_mix, + &mut rlc_parent, + )?; + for (child, Zi) in dec_children.iter_mut().zip(maybe_wits.iter()) { + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance(params, bus, core_t, Zi, child)?; + } + } + } + + // If the main lane carries RV32 trace linkage openings, propagate them through Π_DEC so child + // instances keep the same extra y/y_scalars length (after optional shared-bus openings). 
+ if matches!(lane, RlcLane::Main) && trace_linkage_t_len.is_some() { + let core_t = s.t(); + let trace_open_base = core_t + cpu_bus.map_or(0usize, |bus| bus.bus_cols); + let trace = Rv32TraceLayout::new(); + let trace_cols_to_open: Vec = vec![ + trace.active, + trace.cycle, + trace.pc_before, + trace.instr_word, + trace.rs1_addr, + trace.rs1_val, + trace.rs2_addr, + trace.rs2_val, + trace.rd_addr, + trace.rd_val, + trace.ram_addr, + trace.ram_rv, + trace.ram_wv, + trace.shout_has_lookup, + trace.shout_val, + trace.shout_lhs, + trace.shout_rhs, + ]; + + let want_len = trace_open_base + trace_cols_to_open.len(); + let has_base_only = rlc_parent.y.len() == trace_open_base && rlc_parent.y_scalars.len() == trace_open_base; + let has_trace_openings = rlc_parent.y.len() == want_len && rlc_parent.y_scalars.len() == want_len; + if has_base_only || has_trace_openings { + let m_in = rlc_parent.m_in; + if m_in != 5 { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage openings expect m_in=5 (got {m_in})" + ))); + } + let t_len = trace_linkage_t_len + .ok_or_else(|| PiCcsError::ProtocolError("trace linkage openings require explicit t_len".into()))?; + if t_len == 0 { + return Err(PiCcsError::InvalidInput("trace linkage expects t_len >= 1".into())); + } + let trace_len = trace + .cols + .checked_mul(t_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace cols * t_len overflow".into()))?; + let min_m = m_in + .checked_add(trace_len) + .ok_or_else(|| PiCcsError::InvalidInput("m_in + trace_len overflow".into()))?; + if s.m < min_m { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage openings require m >= m_in + trace.cols*t_len (m={}, min_m={} for t_len={}, trace_cols={})", + s.m, min_m, t_len, trace.cols + ))); + } + + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + /*col_base=*/ m_in, + &trace_cols_to_open, + trace_open_base, + Z_mix, + &mut rlc_parent, + )?; + if dec_children.len() != 
maybe_wits.len() { + return Err(PiCcsError::ProtocolError( + "trace linkage requires materialized DEC witnesses".into(), + )); + } + for (child, Zi) in dec_children.iter_mut().zip(maybe_wits.iter()) { + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + /*col_base=*/ m_in, + &trace_cols_to_open, + trace_open_base, + Zi, + child, + )?; + } + } else { + return Err(PiCcsError::InvalidInput(format!( + "trace linkage openings expect parent y/y_scalars len to be base={} or base+trace_openings={} (got y.len()={}, y_scalars.len()={})", + trace_open_base, + want_len, + rlc_parent.y.len(), + rlc_parent.y_scalars.len(), + ))); + } + } + + Ok(( + RlcDecProof { + rlc_rhos, + rlc_parent, + dec_children, + }, + maybe_wits, + )) +} + +pub(crate) fn verify_rlc_dec_lane( + lane: RlcLane, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s: &CcsStructure, + ring: &ccs::RotRing, + ell_d: usize, + mixers: CommitMixers, + step_idx: usize, + rlc_inputs: &[MeInstance], + rlc_rhos: &[Mat], + rlc_parent: &MeInstance, + dec_children: &[MeInstance], +) -> Result<(), PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + bind_rlc_inputs(tr, lane, step_idx, rlc_inputs)?; + + if rlc_rhos.len() != rlc_inputs.len() { + let prefix = match lane { + RlcLane::Main => "", + RlcLane::Val => "val-lane ", + }; + return Err(PiCcsError::InvalidInput(format!( + "step {}: {}RLC ρ count mismatch (expected {}, got {})", + step_idx, + prefix, + rlc_inputs.len(), + rlc_rhos.len() + ))); + } + + for (i, me) in rlc_inputs.iter().enumerate() { + verify_me_y_scalars_canonical( + me, + params.b, + step_idx, + &format!( + "{}RLC input[{i}]", + match lane { + RlcLane::Main => "", + RlcLane::Val => "val-lane ", + } + ), + )?; + } + + let rhos_from_tr = ccs::sample_rot_rhos_n(tr, params, ring, rlc_inputs.len())?; + for (j, (sampled, stored)) in rhos_from_tr.iter().zip(rlc_rhos.iter()).enumerate() { 
+ if sampled.as_slice() != stored.as_slice() { + return Err(PiCcsError::ProtocolError(match lane { + RlcLane::Main => format!("step {}: RLC ρ #{} mismatch: transcript vs proof", step_idx, j), + RlcLane::Val => format!("step {}: val-lane RLC ρ #{} mismatch: transcript vs proof", step_idx, j), + })); + } + } + + let parent_pub = ccs::rlc_public(s, params, rlc_rhos, rlc_inputs, mixers.mix_rhos_commits, ell_d)?; + + let prefix = match lane { + RlcLane::Main => "", + RlcLane::Val => "val-lane ", + }; + if parent_pub.m_in != rlc_parent.m_in { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC m_in mismatch (public={}, proof={})", + step_idx, parent_pub.m_in, rlc_parent.m_in + ))); + } + if parent_pub.fold_digest != rlc_parent.fold_digest { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC fold_digest mismatch", + step_idx + ))); + } + if parent_pub.c_step_coords != rlc_parent.c_step_coords { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC c_step_coords mismatch", + step_idx + ))); + } + if parent_pub.u_offset != rlc_parent.u_offset { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC u_offset mismatch", + step_idx + ))); + } + if parent_pub.u_len != rlc_parent.u_len { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC u_len mismatch", + step_idx + ))); + } + if parent_pub.X != rlc_parent.X { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC X mismatch", + step_idx + ))); + } + if parent_pub.c != rlc_parent.c { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC commitment mismatch", + step_idx + ))); + } + if parent_pub.r != rlc_parent.r { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC r mismatch", + step_idx + ))); + } + if parent_pub.s_col != rlc_parent.s_col { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC s_col mismatch", + step_idx + ))); + } + if parent_pub.y != 
rlc_parent.y { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC y mismatch", + step_idx + ))); + } + if parent_pub.y_scalars != rlc_parent.y_scalars { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC y_scalars mismatch", + step_idx + ))); + } + if parent_pub.y_zcol != rlc_parent.y_zcol { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC y_zcol mismatch", + step_idx + ))); + } + + if rlc_parent.X.rows() != D || rlc_parent.X.cols() != rlc_parent.m_in { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}RLC parent X shape {}x{} does not match m_in={}", + step_idx, + rlc_parent.X.rows(), + rlc_parent.X.cols(), + rlc_parent.m_in + ))); + } + if !dec_children.is_empty() { + validate_me_batch_invariants(dec_children, "verify step dec children")?; + for (child_idx, child) in dec_children.iter().enumerate() { + if child.m_in != rlc_parent.m_in { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}DEC child[{child_idx}] has m_in={}, expected {}", + step_idx, child.m_in, rlc_parent.m_in + ))); + } + if child.fold_digest != rlc_parent.fold_digest { + return Err(PiCcsError::ProtocolError(format!( + "step {}: {prefix}DEC child[{child_idx}] fold_digest mismatch", + step_idx + ))); + } + } + } + + if !ccs::verify_dec_public(s, params, rlc_parent, dec_children, mixers.combine_b_pows, ell_d) { + return Err(PiCcsError::ProtocolError(match lane { + RlcLane::Main => format!("step {}: DEC public check failed", step_idx), + RlcLane::Val => format!("step {}: val-lane DEC public check failed", step_idx), + })); + } + + Ok(()) +} + +#[cfg(feature = "paper-exact")] +pub(crate) fn crosscheck_route_a_ccs_step( + cfg: &neo_reductions::engines::CrosscheckCfg, + step_idx: usize, + params: &NeoParams, + s: &CcsStructure, + cpu_bus: &neo_memory::cpu::BusLayout, + mcs_inst: &neo_ccs::McsInstance, + mcs_wit: &neo_ccs::McsWitness, + me_inputs: &[MeInstance], + me_witnesses: &[Mat], + ccs_out: 
&[MeInstance], + ccs_proof: &crate::PiCcsProof, + ell_d: usize, + ell_n: usize, + ell_m: usize, + d_sc: usize, + fold_digest: [u8; 32], + log: &L, +) -> Result<(), PiCcsError> +where + L: SModuleHomomorphism + Sync, +{ + let want_rounds_total = ell_n + .checked_add(ell_d) + .ok_or_else(|| PiCcsError::ProtocolError("ell_n + ell_d overflow".into()))?; + if ccs_proof.sumcheck_rounds.len() != want_rounds_total { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck expects {} CCS sumcheck rounds, got {}", + step_idx, + want_rounds_total, + ccs_proof.sumcheck_rounds.len(), + ))); + } + if ccs_proof.sumcheck_challenges.len() != want_rounds_total { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck expects {} CCS sumcheck challenges, got {}", + step_idx, + want_rounds_total, + ccs_proof.sumcheck_challenges.len(), + ))); + } + let (s_col_prime, alpha_prime_nc) = if ccs_proof.variant == crate::optimized_engine::PiCcsProofVariant::SplitNcV1 { + let want_nc_rounds_total = ell_m + .checked_add(ell_d) + .ok_or_else(|| PiCcsError::ProtocolError("ell_m + ell_d overflow".into()))?; + if ccs_proof.sumcheck_rounds_nc.len() != want_nc_rounds_total { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck expects {} NC sumcheck rounds, got {}", + step_idx, + want_nc_rounds_total, + ccs_proof.sumcheck_rounds_nc.len(), + ))); + } + if ccs_proof.sumcheck_challenges_nc.len() != want_nc_rounds_total { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck expects {} NC sumcheck challenges, got {}", + step_idx, + want_nc_rounds_total, + ccs_proof.sumcheck_challenges_nc.len(), + ))); + } + ccs_proof.sumcheck_challenges_nc.split_at(ell_m) + } else { + (&[][..], &[][..]) + }; + + let (r_prime, alpha_prime) = ccs_proof.sumcheck_challenges.split_at(ell_n); + let r_inputs = me_inputs.first().map(|mi| mi.r.as_slice()); + + // Crosscheck initial-sum parity is most informative once there is at least one carried ME + // input. 
For empty-accumulator starts, optimized and paper-exact route through different + // constant-term paths and can diverge without indicating a soundness issue. + if cfg.initial_sum && !me_inputs.is_empty() { + let lhs_exact = crate::paper_exact_engine::sum_q_over_hypercube_paper_exact( + s, + params, + core::slice::from_ref(mcs_wit), + me_witnesses, + &ccs_proof.challenges_public, + ell_d, + ell_n, + r_inputs, + ); + let initial_sum_prover = ccs_proof + .sumcheck_rounds + .first() + .map(|p0| poly_eval_k(p0, K::ZERO) + poly_eval_k(p0, K::ONE)) + .ok_or_else(|| PiCcsError::ProtocolError("crosscheck: missing sumcheck round 0".into()))?; + if lhs_exact != initial_sum_prover { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck initial sum mismatch (optimized vs paper-exact)", + step_idx + ))); + } + } + + if cfg.per_round { + let mut paper_oracle = crate::paper_exact_engine::oracle::PaperExactOracle::new( + s, + params, + core::slice::from_ref(mcs_wit), + me_witnesses, + ccs_proof.challenges_public.clone(), + ell_d, + ell_n, + d_sc, + r_inputs, + ); + + let mut any_mismatch = false; + for (round_idx, (opt_coeffs, &challenge)) in ccs_proof + .sumcheck_rounds + .iter() + .zip(ccs_proof.sumcheck_challenges.iter()) + .enumerate() + { + let deg = paper_oracle.degree_bound(); + let xs: Vec = (0..=deg).map(|t| K::from(F::from_u64(t as u64))).collect(); + let paper_evals = paper_oracle.evals_at(&xs); + + for (&x, &expected) in xs.iter().zip(paper_evals.iter()) { + let actual = poly_eval_k(opt_coeffs, x); + if actual != expected { + any_mismatch = true; + if cfg.fail_fast { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck round {} polynomial mismatch", + step_idx, round_idx + ))); + } + } + } + + paper_oracle.fold(challenge); + } + if any_mismatch { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck per-round polynomial mismatch", + step_idx + ))); + } + } + + if cfg.terminal { + let running_sum_prover = if let 
Some(initial) = ccs_proof.sc_initial_sum { + let mut running = initial; + for (coeffs, &ri) in ccs_proof + .sumcheck_rounds + .iter() + .zip(ccs_proof.sumcheck_challenges.iter()) + { + running = poly_eval_k(coeffs, ri); + } + running + } else { + ccs_proof + .sumcheck_rounds + .first() + .map(|p0| poly_eval_k(p0, K::ZERO) + poly_eval_k(p0, K::ONE)) + .unwrap_or(K::ZERO) + }; + + let rhs_fe = crate::paper_exact_engine::rhs_terminal_identity_fe_paper_exact( + s, + params, + &ccs_proof.challenges_public, + r_prime, + alpha_prime, + ccs_out, + r_inputs, + ); + let (lhs_fe, _rhs_unused) = crate::paper_exact_engine::q_eval_at_ext_point_fe_paper_exact_with_inputs( + s, + params, + core::slice::from_ref(mcs_wit), + me_witnesses, + alpha_prime, + r_prime, + &ccs_proof.challenges_public, + r_inputs, + ); + if rhs_fe != lhs_fe || rhs_fe != running_sum_prover { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck FE terminal evaluation claim mismatch", + step_idx + ))); + } + + let rhs_nc = crate::paper_exact_engine::rhs_terminal_identity_nc_paper_exact( + params, + &ccs_proof.challenges_public, + s_col_prime, + alpha_prime_nc, + ccs_out, + ); + if rhs_nc != ccs_proof.sumcheck_final_nc { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck NC terminal evaluation claim mismatch", + step_idx + ))); + } + } + + if cfg.outputs { + let mut out_me_ref = build_me_outputs_paper_exact( + s, + params, + core::slice::from_ref(mcs_inst), + core::slice::from_ref(mcs_wit), + me_inputs, + me_witnesses, + r_prime, + s_col_prime, + ell_d, + fold_digest, + log, + ); + + if cpu_bus.bus_cols > 0 { + let core_t = s.t(); + if out_me_ref.len() != 1 + me_witnesses.len() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck CCS output count mismatch for bus openings (out_me_ref.len()={}, expected {})", + step_idx, + out_me_ref.len(), + 1 + me_witnesses.len() + ))); + } + + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance( + 
params, + cpu_bus, + core_t, + &mcs_wit.Z, + &mut out_me_ref[0], + )?; + for (out, Z) in out_me_ref.iter_mut().skip(1).zip(me_witnesses.iter()) { + crate::memory_sidecar::cpu_bus::append_bus_openings_to_me_instance(params, cpu_bus, core_t, Z, out)?; + } + + let trace = Rv32TraceLayout::new(); + let trace_cols_to_open: Vec = vec![ + trace.active, + trace.cycle, + trace.pc_before, + trace.instr_word, + trace.rs1_addr, + trace.rs1_val, + trace.rs2_addr, + trace.rs2_val, + trace.rd_addr, + trace.rd_val, + trace.ram_addr, + trace.ram_rv, + trace.ram_wv, + trace.shout_has_lookup, + trace.shout_val, + trace.shout_lhs, + trace.shout_rhs, + ]; + let want_with_trace = core_t + cpu_bus.bus_cols + trace_cols_to_open.len(); + if ccs_out + .first() + .map(|me| me.y_scalars.len() == want_with_trace) + .unwrap_or(false) + { + let m_in = mcs_inst.m_in; + let bus_region_len = cpu_bus + .bus_cols + .checked_mul(cpu_bus.chunk_size) + .ok_or_else(|| PiCcsError::ProtocolError("crosscheck bus region overflow".into()))?; + let trace_region = + s.m.checked_sub(m_in) + .and_then(|v| v.checked_sub(bus_region_len)) + .ok_or_else(|| PiCcsError::ProtocolError("crosscheck trace region underflow".into()))?; + if trace.cols == 0 || trace_region % trace.cols != 0 { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck cannot infer trace t_len (trace_region={}, trace_cols={})", + step_idx, trace_region, trace.cols + ))); + } + let t_len = trace_region / trace.cols; + let trace_open_base = core_t + cpu_bus.bus_cols; + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + m_in, + &trace_cols_to_open, + trace_open_base, + &mcs_wit.Z, + &mut out_me_ref[0], + )?; + for (out, Z) in out_me_ref.iter_mut().skip(1).zip(me_witnesses.iter()) { + crate::memory_sidecar::cpu_bus::append_col_major_time_openings_to_me_instance( + params, + m_in, + t_len, + m_in, + &trace_cols_to_open, + trace_open_base, + Z, + out, + )?; + } + } + } + + if 
out_me_ref.len() != ccs_out.len() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output length mismatch (paper={}, optimized={})", + step_idx, + out_me_ref.len(), + ccs_out.len() + ))); + } + + for (idx, (a, b)) in out_me_ref.iter().zip(ccs_out.iter()).enumerate() { + if a.m_in != b.m_in { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] m_in mismatch (paper={}, optimized={})", + step_idx, a.m_in, b.m_in + ))); + } + if a.r != b.r { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] r mismatch", + step_idx + ))); + } + if a.s_col != b.s_col { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] s_col mismatch", + step_idx + ))); + } + if a.c.data != b.c.data { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] commitment mismatch", + step_idx + ))); + } + if a.y.len() != b.y.len() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] y.len mismatch (paper={}, optimized={})", + step_idx, + a.y.len(), + b.y.len() + ))); + } + for (j, (ya, yb)) in a.y.iter().zip(b.y.iter()).enumerate() { + if ya != yb { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] y row {j} mismatch", + step_idx + ))); + } + } + if a.y_scalars != b.y_scalars { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] y_scalars mismatch", + step_idx + ))); + } + if a.y_zcol != b.y_zcol { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] y_zcol mismatch", + step_idx + ))); + } + if a.X.rows() != b.X.rows() || a.X.cols() != b.X.cols() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] X dims mismatch (paper={}x{}, optimized={}x{})", + step_idx, + a.X.rows(), + a.X.cols(), + b.X.rows(), + b.X.cols() + ))); + } + for r in 0..a.X.rows() { + for c in 0..a.X.cols() { + if a.X[(r, 
c)] != b.X[(r, c)] { + return Err(PiCcsError::ProtocolError(format!( + "step {}: crosscheck output[{idx}] X mismatch at ({},{})", + step_idx, r, c + ))); + } + } + } + } + } + + Ok(()) +} + +// ============================================================================ +// Shard Proving +// ============================================================================ diff --git a/crates/neo-fold/src/shard/verifier_and_api.rs b/crates/neo-fold/src/shard/verifier_and_api.rs new file mode 100644 index 00000000..e2e37eee --- /dev/null +++ b/crates/neo-fold/src/shard/verifier_and_api.rs @@ -0,0 +1,1433 @@ +use super::*; + +pub fn fold_shard_prove( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, +) -> Result +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( + false, + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + acc_wit_init, + l, + mixers, + None, + None, + None, + )?; + Ok(proof) +} + +pub(crate) fn fold_shard_prove_with_context( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, + ctx: &ShardProverContext, +) -> Result +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( + false, + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + acc_wit_init, + l, + mixers, + None, + Some(ctx), + None, + )?; + Ok(proof) +} + +pub(crate) fn fold_shard_prove_with_context_and_step_timings( + mode: FoldingMode, + tr: &mut 
Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, + ctx: &ShardProverContext, +) -> Result<(ShardProof, Vec), PiCcsError> +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + let mut step_prove_ms = Vec::with_capacity(steps.len()); + let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( + false, + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + acc_wit_init, + l, + mixers, + None, + Some(ctx), + Some(&mut step_prove_ms), + )?; + Ok((proof, step_prove_ms)) +} + +pub fn fold_shard_prove_with_output_binding( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, + ob_cfg: &crate::output_binding::OutputBindingConfig, + final_memory_state: &[F], +) -> Result +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + let (proof, _final_main_wits, _val_lane_wits) = fold_shard_prove_impl( + false, + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + acc_wit_init, + l, + mixers, + Some((ob_cfg, final_memory_state)), + None, + None, + )?; + Ok(proof) +} + +pub(crate) fn fold_shard_prove_with_output_binding_with_context( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, + ob_cfg: &crate::output_binding::OutputBindingConfig, + final_memory_state: &[F], + ctx: &ShardProverContext, +) -> Result +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + let (proof, 
_final_main_wits, _val_lane_wits) = fold_shard_prove_impl( + false, + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + acc_wit_init, + l, + mixers, + Some((ob_cfg, final_memory_state)), + Some(ctx), + None, + )?; + Ok(proof) +} + +pub fn fold_shard_prove_with_witnesses( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, +) -> Result<(ShardProof, ShardFoldOutputs, ShardFoldWitnesses), PiCcsError> +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + let (proof, final_main_wits, val_lane_wits) = fold_shard_prove_impl( + true, + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + acc_wit_init, + l, + mixers, + None, + None, + None, + )?; + let outputs = proof.compute_fold_outputs(acc_init); + if outputs.obligations.main.len() != final_main_wits.len() { + return Err(PiCcsError::ProtocolError(format!( + "final main witness count mismatch (have {}, need {})", + final_main_wits.len(), + outputs.obligations.main.len() + ))); + } + if outputs.obligations.val.len() != val_lane_wits.len() { + return Err(PiCcsError::ProtocolError(format!( + "val-lane witness count mismatch (have {}, need {})", + val_lane_wits.len(), + outputs.obligations.val.len() + ))); + } + Ok(( + proof, + outputs, + ShardFoldWitnesses { + final_main_wits, + val_lane_wits, + }, + )) +} + +/// Same as `fold_shard_prove_with_witnesses`, but offsets the per-step transcript index by `step_idx_offset`. +/// +/// This is useful for "continuation" style proving across multiple calls while preserving a globally +/// increasing step index for transcript domain separation. 
+pub fn fold_shard_prove_with_witnesses_with_step_offset( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepWitnessBundle], + acc_init: &[MeInstance], + acc_wit_init: &[Mat], + l: &L, + mixers: CommitMixers, + step_idx_offset: usize, +) -> Result<(ShardProof, ShardFoldOutputs, ShardFoldWitnesses), PiCcsError> +where + L: SModuleHomomorphism + Sync, + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + let (proof, final_main_wits, val_lane_wits) = fold_shard_prove_impl( + true, + mode, + tr, + params, + s_me, + steps, + step_idx_offset, + acc_init, + acc_wit_init, + l, + mixers, + None, + None, + None, + )?; + let outputs = proof.compute_fold_outputs(acc_init); + if outputs.obligations.main.len() != final_main_wits.len() { + return Err(PiCcsError::ProtocolError(format!( + "final main witness count mismatch (have {}, need {})", + final_main_wits.len(), + outputs.obligations.main.len() + ))); + } + if outputs.obligations.val.len() != val_lane_wits.len() { + return Err(PiCcsError::ProtocolError(format!( + "val-lane witness count mismatch (have {}, need {})", + val_lane_wits.len(), + outputs.obligations.val.len() + ))); + } + Ok(( + proof, + outputs, + ShardFoldWitnesses { + final_main_wits, + val_lane_wits, + }, + )) +} + +// ============================================================================ +// Shard Verification +// ============================================================================ + +pub(crate) fn fold_shard_verify_impl( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + step_idx_offset: usize, + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + ob_cfg: Option<&crate::output_binding::OutputBindingConfig>, + prover_ctx: Option<&ShardProverContext>, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: 
Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + for (step_idx, step) in steps.iter().enumerate() { + if step.lut_insts.is_empty() && step.mem_insts.is_empty() { + continue; + } + let is_shared_step = step.lut_insts.iter().all(|inst| inst.comms.is_empty()) + && step.mem_insts.iter().all(|inst| inst.comms.is_empty()); + if !is_shared_step { + return Err(PiCcsError::InvalidInput(format!( + "legacy no-shared CPU bus mode was removed; step_idx={step_idx} must use shared-bus statement format" + ))); + } + } + tr.append_message(b"shard/cpu_bus_mode", &[1u8]); + let (s, cpu_bus) = crate::memory_sidecar::cpu_bus::prepare_ccs_for_shared_cpu_bus_steps(s_me, steps)?; + let dims = utils::build_dims_and_policy(params, s)?; + let utils::Dims { + ell_d, + ell_n, + ell_m, + ell, + d_sc, + .. + } = dims; + let ring = ccs::RotRing::goldilocks(); + + if steps.len() != proof.steps.len() { + return Err(PiCcsError::InvalidInput(format!( + "step count mismatch: public {} vs proof {}", + steps.len(), + proof.steps.len() + ))); + } + if ob_cfg.is_some() && steps.is_empty() { + return Err(PiCcsError::InvalidInput("output binding requires >= 1 step".into())); + } + if ob_cfg.is_none() && proof.output_proof.is_some() { + return Err(PiCcsError::InvalidInput( + "shard proof contains output binding, but verifier did not supply OutputBindingConfig".into(), + )); + } + if ob_cfg.is_some() && proof.output_proof.is_none() { + return Err(PiCcsError::InvalidInput( + "verifier supplied OutputBindingConfig, but shard proof has no output binding".into(), + )); + } + + let mut accumulator = acc_init.to_vec(); + let mut val_lane_obligations: Vec> = Vec::new(); + let ccs_sparse_cache: Option>> = if mode_uses_sparse_cache(&mode) { + Some( + prover_ctx + .and_then(|ctx| ctx.ccs_sparse_cache.clone()) + .unwrap_or_else(|| Arc::new(SparseCache::build(s))), + ) + } else { + None + }; + let ccs_mat_digest = prover_ctx + .map(|ctx| ctx.ccs_mat_digest.clone()) + .unwrap_or_else(|| 
utils::digest_ccs_matrices_with_sparse_cache(s, ccs_sparse_cache.as_deref())); + + for (idx, (step, step_proof)) in steps.iter().zip(proof.steps.iter()).enumerate() { + let step_idx = step_idx_offset + .checked_add(idx) + .ok_or_else(|| PiCcsError::InvalidInput("step index overflow".into()))?; + let has_prev = idx > 0; + absorb_step_memory(tr, step); + + let include_ob = ob_cfg.is_some() && (idx + 1 == steps.len()); + let mut ob_state: Option = None; + let mut ob_inc_total_degree_bound: Option = None; + + if include_ob { + let cfg = + ob_cfg.ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but config missing".into()))?; + let ob_proof = proof + .output_proof + .as_ref() + .ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but proof missing".into()))?; + + if cfg.mem_idx >= step.mem_insts.len() { + return Err(PiCcsError::InvalidInput("output binding mem_idx out of range".into())); + } + let mem_inst = step + .mem_insts + .get(cfg.mem_idx) + .ok_or_else(|| PiCcsError::InvalidInput("output binding mem_idx out of range".into()))?; + let expected_k = 1usize + .checked_shl(cfg.num_bits as u32) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^num_bits overflow".into()))?; + if mem_inst.k != expected_k { + return Err(PiCcsError::InvalidInput(format!( + "output binding: cfg.num_bits implies k={}, but mem_inst.k={}", + expected_k, mem_inst.k + ))); + } + let ell_addr = mem_inst.twist_layout().lanes[0].ell_addr; + if ell_addr != cfg.num_bits { + return Err(PiCcsError::InvalidInput(format!( + "output binding: cfg.num_bits={}, but twist_layout.ell_addr={}", + cfg.num_bits, ell_addr + ))); + } + + tr.append_message(b"shard/output_binding_start", &(step_idx as u64).to_le_bytes()); + tr.append_u64s(b"output_binding/mem_idx", &[cfg.mem_idx as u64]); + tr.append_u64s(b"output_binding/num_bits", &[cfg.num_bits as u64]); + + let state = neo_memory::output_check::verify_output_sumcheck_rounds_get_state( + tr, + cfg.num_bits, + 
cfg.program_io.clone(), + &ob_proof.output_sc, + ) + .map_err(|e| PiCcsError::ProtocolError(format!("output sumcheck failed: {e:?}")))?; + ob_inc_total_degree_bound = Some(2 + ell_addr); + ob_state = Some(state); + } + + let mcs_inst = &step.mcs_inst; + + // -------------------------------------------------------------------- + // Route A: Verify shared-challenge batched sum-check (time/row rounds), + // then finish CCS Ajtai rounds, then proceed with RLC→DEC as before. + // -------------------------------------------------------------------- + + // Bind CCS header + ME inputs and sample public challenges. + utils::bind_header_and_instances_with_digest( + tr, + params, + &s, + core::slice::from_ref(mcs_inst), + dims, + &ccs_mat_digest, + )?; + utils::bind_me_inputs(tr, &accumulator)?; + let mut ch = utils::sample_challenges(tr, ell_d, ell)?; + if step_proof.fold.ccs_proof.variant == crate::optimized_engine::PiCcsProofVariant::SplitNcV1 { + ch.beta_m = utils::sample_beta_m(tr, ell_m)?; + } + let expected_ch = &step_proof.fold.ccs_proof.challenges_public; + if expected_ch.alpha != ch.alpha + || expected_ch.beta_a != ch.beta_a + || expected_ch.beta_r != ch.beta_r + || expected_ch.beta_m != ch.beta_m + || expected_ch.gamma != ch.gamma + { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS challenges_public mismatch", + idx + ))); + } + + // Public initial sum T for CCS sumcheck (engine-selected). 
+ let claimed_initial = match &mode { + FoldingMode::Optimized => crate::optimized_engine::claimed_initial_sum_from_inputs(&s, &ch, &accumulator), + #[cfg(feature = "paper-exact")] + FoldingMode::PaperExact => { + crate::paper_exact_engine::claimed_initial_sum_from_inputs(&s, &ch, &accumulator) + } + #[cfg(feature = "paper-exact")] + FoldingMode::OptimizedWithCrosscheck(_) => { + crate::optimized_engine::claimed_initial_sum_from_inputs(&s, &ch, &accumulator) + } + }; + if let Some(x) = step_proof.fold.ccs_proof.sc_initial_sum { + if x != claimed_initial { + return Err(PiCcsError::SumcheckError( + "initial sum mismatch: proof claims different value than public T".into(), + )); + } + } + tr.append_fields(b"sumcheck/initial_sum", &claimed_initial.as_coeffs()); + + // Route A memory checks use a separate transcript-derived cycle point `r_cycle` + // to form χ_{r_cycle}(t) weights inside their sum-check polynomials. + let r_cycle: Vec = + ts::sample_ext_point(tr, b"route_a/r_cycle", b"route_a/cycle/0", b"route_a/cycle/1", ell_n); + + let shout_pre = crate::memory_sidecar::memory::verify_shout_addr_pre_time(tr, step, &step_proof.mem, step_idx)?; + let twist_pre = crate::memory_sidecar::memory::verify_twist_addr_pre_time(tr, step, &step_proof.mem)?; + let wb_enabled = crate::memory_sidecar::memory::wb_wp_required_for_step_instance(step); + let wp_enabled = crate::memory_sidecar::memory::wb_wp_required_for_step_instance(step); + let decode_stage_enabled = crate::memory_sidecar::memory::decode_stage_required_for_step_instance(step); + let width_stage_enabled = crate::memory_sidecar::memory::width_stage_required_for_step_instance(step); + let control_stage_enabled = crate::memory_sidecar::memory::control_stage_required_for_step_instance(step); + let crate::memory_sidecar::route_a_time::RouteABatchedTimeVerifyOutput { r_time, final_values } = + crate::memory_sidecar::route_a_time::verify_route_a_batched_time( + tr, + step_idx, + ell_n, + d_sc, + claimed_initial, + step, + 
&step_proof.batched_time, + wb_enabled, + wp_enabled, + decode_stage_enabled, + width_stage_enabled, + control_stage_enabled, + ob_inc_total_degree_bound, + )?; + + // CCS proof structure consistency with batched time proof. + let want_rounds_total = ell_n + ell_d; + if step_proof.fold.ccs_proof.sumcheck_rounds.len() != want_rounds_total { + return Err(PiCcsError::InvalidInput(format!( + "step {}: CCS sumcheck_rounds.len()={}, expected {}", + idx, + step_proof.fold.ccs_proof.sumcheck_rounds.len(), + want_rounds_total + ))); + } + if step_proof.fold.ccs_proof.sumcheck_challenges.len() != want_rounds_total { + return Err(PiCcsError::InvalidInput(format!( + "step {}: CCS sumcheck_challenges.len()={}, expected {}", + idx, + step_proof.fold.ccs_proof.sumcheck_challenges.len(), + want_rounds_total + ))); + } + for (round_idx, (a, b)) in step_proof + .fold + .ccs_proof + .sumcheck_rounds + .iter() + .take(ell_n) + .zip(step_proof.batched_time.round_polys[0].iter()) + .enumerate() + { + if a != b { + return Err(PiCcsError::ProtocolError(format!( + "step {}: CCS time round poly mismatch at round {}", + idx, round_idx + ))); + } + } + + if step_proof.fold.ccs_proof.sumcheck_challenges[..ell_n] != r_time { + return Err(PiCcsError::ProtocolError(format!( + "step {}: CCS time challenges mismatch with r_time", + idx + ))); + } + + let expected_k = accumulator.len() + 1; + if step_proof.fold.ccs_out.len() != expected_k { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS returned {} outputs; expected k={}", + idx, + step_proof.fold.ccs_out.len(), + expected_k + ))); + } + if step_proof.fold.ccs_out.is_empty() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS produced empty ccs_out", + idx + ))); + } + if step_proof.fold.ccs_out[0].r != r_time { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output r != r_time (Route A requires shared r)", + idx + ))); + } + + // Bind Π_CCS outputs to the public MCS instance and carried ME 
inputs. + // + // - Commitments must match (Π_CCS does not change commitments). + // - `X` must match the digit-decomposition of public `x` for the MCS output. + // - `X` must match the carried ME inputs for subsequent outputs. + { + let out0 = &step_proof.fold.ccs_out[0]; + if out0.c != mcs_inst.c { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[0].c does not match mcs_inst.c", + idx + ))); + } + if out0.m_in != mcs_inst.m_in { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[0].m_in={}, expected {}", + idx, out0.m_in, mcs_inst.m_in + ))); + } + if out0.X.rows() != D || out0.X.cols() != mcs_inst.m_in { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[0].X has shape {}×{}, expected {}×{}", + idx, + out0.X.rows(), + out0.X.cols(), + D, + mcs_inst.m_in + ))); + } + + for (i, inp) in accumulator.iter().enumerate() { + let out = &step_proof.fold.ccs_out[i + 1]; + if out.c != inp.c { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[{}].c does not match accumulator[{}].c", + idx, + i + 1, + i + ))); + } + if out.m_in != inp.m_in { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[{}].m_in={}, expected {}", + idx, + i + 1, + out.m_in, + inp.m_in + ))); + } + if out.X != inp.X { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[{}].X does not match accumulator[{}].X", + idx, + i + 1, + i + ))); + } + } + } + + // Finish CCS Ajtai rounds alone (continuing transcript state after batched rounds). + let ajtai_rounds = &step_proof.fold.ccs_proof.sumcheck_rounds[ell_n..]; + let (ajtai_chals, running_sum, ok) = + verify_sumcheck_rounds_ds(tr, b"ccs/ajtai", step_idx, d_sc, final_values[0], ajtai_rounds); + if !ok { + return Err(PiCcsError::SumcheckError("Π_CCS Ajtai rounds invalid".into())); + } + + // Verify stored sumcheck challenges/final match transcript-derived values. 
+ let mut r_all = r_time.clone(); + r_all.extend_from_slice(&ajtai_chals); + if r_all != step_proof.fold.ccs_proof.sumcheck_challenges { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS sumcheck challenges mismatch", + idx + ))); + } + if running_sum != step_proof.fold.ccs_proof.sumcheck_final { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS sumcheck_final mismatch", + idx + ))); + } + + // Validate ME input r length (required by RHS assembly if k>1). + for (i, me) in accumulator.iter().enumerate() { + if me.r.len() != ell_n { + return Err(PiCcsError::InvalidInput(format!( + "step {}: ME input r length mismatch at accumulator #{}: expected {}, got {}", + idx, + i, + ell_n, + me.r.len() + ))); + } + } + + if step_proof.fold.ccs_proof.variant != crate::optimized_engine::PiCcsProofVariant::SplitNcV1 { + return Err(PiCcsError::ProtocolError("unsupported Π_CCS proof variant".into())); + } + + // FE-only terminal identity. + let rhs_fe = crate::paper_exact_engine::rhs_terminal_identity_fe_paper_exact( + &s, + params, + &ch, + &r_time, + &ajtai_chals, + &step_proof.fold.ccs_out, + accumulator.first().map(|mi| mi.r.as_slice()), + ); + if running_sum != rhs_fe { + return Err(PiCcsError::SumcheckError( + "Π_CCS FE-only terminal identity check failed".into(), + )); + } + + // NC-only sumcheck + terminal identity. 
+ if step_proof.fold.ccs_proof.sumcheck_rounds_nc.is_empty() { + return Err(PiCcsError::InvalidInput( + "Π_CCS SplitNcV1 requires non-empty sumcheck_rounds_nc".into(), + )); + } + if let Some(x) = step_proof.fold.ccs_proof.sc_initial_sum_nc { + if x != K::ZERO { + return Err(PiCcsError::InvalidInput( + "Π_CCS SplitNcV1 requires sc_initial_sum_nc == 0".into(), + )); + } + } + let want_nc_rounds_total = ell_m + .checked_add(ell_d) + .ok_or_else(|| PiCcsError::ProtocolError("ell_m + ell_d overflow".into()))?; + if step_proof.fold.ccs_proof.sumcheck_rounds_nc.len() != want_nc_rounds_total { + return Err(PiCcsError::InvalidInput(format!( + "step {}: Π_CCS NC sumcheck_rounds_nc.len()={}, expected {}", + idx, + step_proof.fold.ccs_proof.sumcheck_rounds_nc.len(), + want_nc_rounds_total + ))); + } + if step_proof.fold.ccs_proof.sumcheck_challenges_nc.len() != want_nc_rounds_total { + return Err(PiCcsError::InvalidInput(format!( + "step {}: Π_CCS NC sumcheck_challenges_nc.len()={}, expected {}", + idx, + step_proof.fold.ccs_proof.sumcheck_challenges_nc.len(), + want_nc_rounds_total + ))); + } + + let (nc_chals, running_sum_nc, ok_nc) = verify_sumcheck_rounds_ds( + tr, + b"ccs/nc", + step_idx, + d_sc, + K::ZERO, + &step_proof.fold.ccs_proof.sumcheck_rounds_nc, + ); + if !ok_nc { + return Err(PiCcsError::SumcheckError("Π_CCS NC rounds invalid".into())); + } + + if nc_chals != step_proof.fold.ccs_proof.sumcheck_challenges_nc { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS NC sumcheck challenges mismatch", + idx + ))); + } + if running_sum_nc != step_proof.fold.ccs_proof.sumcheck_final_nc { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS sumcheck_final_nc mismatch", + idx + ))); + } + + let (s_col_prime, alpha_prime_nc) = nc_chals.split_at(ell_m); + let d_pad = 1usize + .checked_shl(ell_d as u32) + .ok_or_else(|| PiCcsError::ProtocolError("2^ell_d overflow".into()))?; + for (out_idx, out) in step_proof.fold.ccs_out.iter().enumerate() { + if 
out.r != r_time { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[{out_idx}] r != r_time", + idx + ))); + } + if out.s_col.as_slice() != s_col_prime { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[{out_idx}] s_col mismatch", + idx + ))); + } + if out.y_zcol.len() != d_pad { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[{out_idx}] y_zcol.len()={}, expected {}", + idx, + out.y_zcol.len(), + d_pad + ))); + } + } + + let rhs_nc = crate::paper_exact_engine::rhs_terminal_identity_nc_paper_exact( + params, + &ch, + s_col_prime, + alpha_prime_nc, + &step_proof.fold.ccs_out, + ); + if running_sum_nc != rhs_nc { + return Err(PiCcsError::SumcheckError( + "Π_CCS NC terminal identity check failed".into(), + )); + } + + let observed_digest = tr.digest32(); + if observed_digest != step_proof.fold.ccs_proof.header_digest.as_slice() { + return Err(PiCcsError::ProtocolError("Π_CCS header digest mismatch".into())); + } + let expected_digest: [u8; 32] = step_proof + .fold + .ccs_proof + .header_digest + .as_slice() + .try_into() + .map_err(|_| PiCcsError::ProtocolError("Π_CCS header digest must be 32 bytes".into()))?; + for (out_idx, out) in step_proof.fold.ccs_out.iter().enumerate() { + if out.fold_digest != expected_digest { + return Err(PiCcsError::ProtocolError(format!( + "step {}: Π_CCS output[{out_idx}] fold_digest mismatch", + idx + ))); + } + } + + // Verify mem proofs (shared CPU bus only). 
+ let prev_step = (idx > 0).then(|| &steps[idx - 1]); + let mem_out = crate::memory_sidecar::memory::verify_route_a_memory_step( + tr, + &cpu_bus, + s.m, + s.t(), + step, + prev_step, + &step_proof.fold.ccs_out[0], + &r_time, + &r_cycle, + &final_values, + &step_proof.batched_time.claimed_sums, + 1, // claim 0 is CCS/time + &step_proof.mem, + &shout_pre, + &twist_pre, + step_idx, + )?; + + let expected_consumed = if include_ob { + final_values + .len() + .checked_sub(1) + .ok_or_else(|| PiCcsError::ProtocolError("missing output binding claim".into()))? + } else { + final_values.len() + }; + if mem_out.claim_idx_end != expected_consumed { + return Err(PiCcsError::ProtocolError(format!( + "step {}: batched claim index mismatch (consumed {}, expected {})", + idx, mem_out.claim_idx_end, expected_consumed + ))); + } + + if include_ob { + let cfg = + ob_cfg.ok_or_else(|| PiCcsError::InvalidInput("output binding enabled but config missing".into()))?; + let ob_state = ob_state + .take() + .ok_or_else(|| PiCcsError::ProtocolError("output sumcheck state missing".into()))?; + + let inc_idx = final_values + .len() + .checked_sub(1) + .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total claim".into()))?; + if step_proof.batched_time.labels.get(inc_idx).copied() != Some(crate::output_binding::OB_INC_TOTAL_LABEL) { + return Err(PiCcsError::ProtocolError("output binding claim not last".into())); + } + + let inc_total_claim = *step_proof + .batched_time + .claimed_sums + .get(inc_idx) + .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total claimed_sum".into()))?; + let inc_total_final = *final_values + .get(inc_idx) + .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total final_value".into()))?; + + let twist_open = mem_out + .twist_time_openings + .get(cfg.mem_idx) + .ok_or_else(|| PiCcsError::ProtocolError("missing twist_time_openings for mem_idx".into()))?; + let inc_terminal = crate::output_binding::inc_terminal_from_time_openings(twist_open, 
&ob_state.r_prime) + .map_err(|e| PiCcsError::ProtocolError(format!("inc_total terminal mismatch: {e:?}")))?; + if inc_total_final != inc_terminal { + return Err(PiCcsError::ProtocolError("inc_total terminal mismatch".into())); + } + + let mem_inst = step + .mem_insts + .get(cfg.mem_idx) + .ok_or_else(|| PiCcsError::InvalidInput("output binding mem_idx out of range".into()))?; + let expected_k = 1usize + .checked_shl(cfg.num_bits as u32) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^num_bits overflow".into()))?; + if mem_inst.k != expected_k { + return Err(PiCcsError::InvalidInput(format!( + "output binding: cfg.num_bits implies k={}, but mem_inst.k={}", + expected_k, mem_inst.k + ))); + } + let ell_addr = mem_inst.twist_layout().lanes[0].ell_addr; + if ell_addr != cfg.num_bits { + return Err(PiCcsError::InvalidInput(format!( + "output binding: cfg.num_bits={}, but twist_layout.ell_addr={}", + cfg.num_bits, ell_addr + ))); + } + let val_init = crate::output_binding::val_init_from_mem_init(&mem_inst.init, mem_inst.k, &ob_state.r_prime) + .map_err(|e| PiCcsError::ProtocolError(format!("MemInit eval failed: {e:?}")))?; + + let val_final_at_r_prime = val_init + inc_total_claim; + let expected_out = ob_state.eq_eval * ob_state.io_mask_eval * (val_final_at_r_prime - ob_state.val_io_eval); + if expected_out != ob_state.output_final { + return Err(PiCcsError::ProtocolError("output binding final check failed".into())); + } + } + + validate_me_batch_invariants(&step_proof.fold.ccs_out, "verify step ccs outputs")?; + verify_rlc_dec_lane( + RlcLane::Main, + tr, + params, + &s, + &ring, + ell_d, + mixers, + step_idx, + &step_proof.fold.ccs_out, + &step_proof.fold.rlc_rhos, + &step_proof.fold.rlc_parent, + &step_proof.fold.dec_children, + )?; + + accumulator = step_proof.fold.dec_children.clone(); + + // Phase 2: Verify folding lanes for ME claims evaluated at r_val. 
+ if step_proof.mem.val_me_claims.is_empty() { + if !step_proof.val_fold.is_empty() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: unexpected val_fold proof(s) (no r_val ME claims)", + idx + ))); + } + } else { + tr.append_message(b"fold/val_lane_start", &(step_idx as u64).to_le_bytes()); + let expected = 1usize + usize::from(has_prev); + if step_proof.mem.val_me_claims.len() != expected { + return Err(PiCcsError::ProtocolError(format!( + "step {}: val_me_claims count mismatch in shared-bus mode (have {}, expected {})", + idx, + step_proof.mem.val_me_claims.len(), + expected + ))); + } + if step_proof.val_fold.len() != expected { + return Err(PiCcsError::ProtocolError(format!( + "step {}: val_fold count mismatch in shared-bus mode (have {}, expected {})", + idx, + step_proof.val_fold.len(), + expected + ))); + } + + for (claim_idx, (me, proof)) in step_proof + .mem + .val_me_claims + .iter() + .zip(step_proof.val_fold.iter()) + .enumerate() + { + let ctx = match claim_idx { + 0 => "cpu", + 1 => "cpu_prev", + _ => { + return Err(PiCcsError::ProtocolError( + "unexpected extra r_val ME claim in shared-bus mode".into(), + )); + } + }; + tr.append_message(b"fold/val_lane_claim_idx", &(claim_idx as u64).to_le_bytes()); + tr.append_message(b"fold/val_lane_claim_ctx", ctx.as_bytes()); + verify_rlc_dec_lane( + RlcLane::Val, + tr, + params, + &s, + &ring, + ell_d, + mixers, + step_idx, + core::slice::from_ref(me), + &proof.rlc_rhos, + &proof.rlc_parent, + &proof.dec_children, + ) + .map_err(|e| { + PiCcsError::ProtocolError(format!( + "step {} val_fold(shared) claim {} ({ctx}) verify failed: {e:?}", + idx, claim_idx + )) + })?; + val_lane_obligations.extend_from_slice(&proof.dec_children); + } + } + + if step_proof.mem.wb_me_claims.is_empty() { + if !step_proof.wb_fold.is_empty() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: unexpected wb_fold proof(s) (no WB ME claims)", + idx + ))); + } + } else { + if step_proof.wb_fold.len() != 
step_proof.mem.wb_me_claims.len() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: wb_fold count mismatch (have {}, expected {})", + idx, + step_proof.wb_fold.len(), + step_proof.mem.wb_me_claims.len() + ))); + } + tr.append_message(b"fold/wb_lane_start", &(step_idx as u64).to_le_bytes()); + for (claim_idx, (me, proof)) in step_proof + .mem + .wb_me_claims + .iter() + .zip(step_proof.wb_fold.iter()) + .enumerate() + { + tr.append_message(b"fold/wb_lane_claim_idx", &(claim_idx as u64).to_le_bytes()); + verify_rlc_dec_lane( + RlcLane::Val, + tr, + params, + &s, + &ring, + ell_d, + mixers, + step_idx, + core::slice::from_ref(me), + &proof.rlc_rhos, + &proof.rlc_parent, + &proof.dec_children, + ) + .map_err(|e| { + PiCcsError::ProtocolError(format!("step {} wb_fold claim {} verify failed: {e:?}", idx, claim_idx)) + })?; + val_lane_obligations.extend_from_slice(&proof.dec_children); + } + } + + if step_proof.mem.wp_me_claims.is_empty() { + if !step_proof.wp_fold.is_empty() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: unexpected wp_fold proof(s) (no WP ME claims)", + idx + ))); + } + } else { + if step_proof.wp_fold.len() != step_proof.mem.wp_me_claims.len() { + return Err(PiCcsError::ProtocolError(format!( + "step {}: wp_fold count mismatch (have {}, expected {})", + idx, + step_proof.wp_fold.len(), + step_proof.mem.wp_me_claims.len() + ))); + } + tr.append_message(b"fold/wp_lane_start", &(step_idx as u64).to_le_bytes()); + for (claim_idx, (me, proof)) in step_proof + .mem + .wp_me_claims + .iter() + .zip(step_proof.wp_fold.iter()) + .enumerate() + { + tr.append_message(b"fold/wp_lane_claim_idx", &(claim_idx as u64).to_le_bytes()); + verify_rlc_dec_lane( + RlcLane::Val, + tr, + params, + &s, + &ring, + ell_d, + mixers, + step_idx, + core::slice::from_ref(me), + &proof.rlc_rhos, + &proof.rlc_parent, + &proof.dec_children, + ) + .map_err(|e| { + PiCcsError::ProtocolError(format!("step {} wp_fold claim {} verify failed: {e:?}", idx, 
claim_idx)) + })?; + val_lane_obligations.extend_from_slice(&proof.dec_children); + } + } + + tr.append_message(b"fold/step_done", &(step_idx as u64).to_le_bytes()); + } + + Ok(ShardFoldOutputs { + obligations: ShardObligations { + main: accumulator, + val: val_lane_obligations, + }, + }) +} + +pub fn fold_shard_verify( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + fold_shard_verify_impl(mode, tr, params, s_me, steps, 0, acc_init, proof, mixers, None, None) +} + +/// Same as `fold_shard_verify`, but offsets the per-step transcript index by `step_idx_offset`. +pub fn fold_shard_verify_with_step_offset( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + step_idx_offset: usize, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + fold_shard_verify_impl( + mode, + tr, + params, + s_me, + steps, + step_idx_offset, + acc_init, + proof, + mixers, + None, + None, + ) +} + +pub fn fold_shard_verify_with_step_linking( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + step_linking: &StepLinkingConfig, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + check_step_linking(steps, step_linking)?; + fold_shard_verify(mode, tr, params, s_me, steps, acc_init, proof, mixers) +} + +pub fn fold_shard_verify_with_output_binding( + mode: 
FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + ob_cfg: &crate::output_binding::OutputBindingConfig, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + fold_shard_verify_impl( + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + proof, + mixers, + Some(ob_cfg), + None, + ) +} + +pub(crate) fn fold_shard_verify_with_context( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + prover_ctx: &ShardProverContext, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + fold_shard_verify_impl( + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + proof, + mixers, + None, + Some(prover_ctx), + ) +} + +pub(crate) fn fold_shard_verify_with_step_linking_with_context( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + step_linking: &StepLinkingConfig, + prover_ctx: &ShardProverContext, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + check_step_linking(steps, step_linking)?; + fold_shard_verify_with_context(mode, tr, params, s_me, steps, acc_init, proof, mixers, prover_ctx) +} + +pub(crate) fn fold_shard_verify_with_output_binding_with_context( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + ob_cfg: 
&crate::output_binding::OutputBindingConfig, + prover_ctx: &ShardProverContext, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + fold_shard_verify_impl( + mode, + tr, + params, + s_me, + steps, + 0, + acc_init, + proof, + mixers, + Some(ob_cfg), + Some(prover_ctx), + ) +} + +pub(crate) fn fold_shard_verify_with_output_binding_and_step_linking_with_context( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + ob_cfg: &crate::output_binding::OutputBindingConfig, + step_linking: &StepLinkingConfig, + prover_ctx: &ShardProverContext, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + check_step_linking(steps, step_linking)?; + fold_shard_verify_with_output_binding_with_context( + mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg, prover_ctx, + ) +} + +pub fn fold_shard_verify_with_output_binding_and_step_linking( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + ob_cfg: &crate::output_binding::OutputBindingConfig, + step_linking: &StepLinkingConfig, +) -> Result, PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, +{ + check_step_linking(steps, step_linking)?; + fold_shard_verify_with_output_binding(mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg) +} + +pub fn fold_shard_verify_and_finalize( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + finalizer: &mut Fin, +) -> 
Result<(), PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, + Fin: ObligationFinalizer, +{ + let outputs = fold_shard_verify(mode, tr, params, s_me, steps, acc_init, proof, mixers)?; + let report = finalizer.finalize(&outputs.obligations)?; + outputs + .obligations + .require_all_finalized(report.did_finalize_main, report.did_finalize_val)?; + Ok(()) +} + +pub fn fold_shard_verify_and_finalize_with_step_linking( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + step_linking: &StepLinkingConfig, + finalizer: &mut Fin, +) -> Result<(), PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, + Fin: ObligationFinalizer, +{ + check_step_linking(steps, step_linking)?; + fold_shard_verify_and_finalize(mode, tr, params, s_me, steps, acc_init, proof, mixers, finalizer) +} + +pub fn fold_shard_verify_and_finalize_with_output_binding( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, + s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + ob_cfg: &crate::output_binding::OutputBindingConfig, + finalizer: &mut Fin, +) -> Result<(), PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, + Fin: ObligationFinalizer, +{ + let outputs = + fold_shard_verify_with_output_binding(mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg)?; + let report = finalizer.finalize(&outputs.obligations)?; + outputs + .obligations + .require_all_finalized(report.did_finalize_main, report.did_finalize_val)?; + Ok(()) +} + +pub fn fold_shard_verify_and_finalize_with_output_binding_and_step_linking( + mode: FoldingMode, + tr: &mut Poseidon2Transcript, + params: &NeoParams, 
+ s_me: &CcsStructure, + steps: &[StepInstanceBundle], + acc_init: &[MeInstance], + proof: &ShardProof, + mixers: CommitMixers, + ob_cfg: &crate::output_binding::OutputBindingConfig, + step_linking: &StepLinkingConfig, + finalizer: &mut Fin, +) -> Result<(), PiCcsError> +where + MR: Fn(&[Mat], &[Cmt]) -> Cmt + Clone + Copy, + MB: Fn(&[Cmt], u32) -> Cmt + Clone + Copy, + Fin: ObligationFinalizer, +{ + check_step_linking(steps, step_linking)?; + fold_shard_verify_and_finalize_with_output_binding( + mode, tr, params, s_me, steps, acc_init, proof, mixers, ob_cfg, finalizer, + ) +} diff --git a/crates/neo-fold/src/shard_proof_types.rs b/crates/neo-fold/src/shard_proof_types.rs index 51534325..261b7cb8 100644 --- a/crates/neo-fold/src/shard_proof_types.rs +++ b/crates/neo-fold/src/shard_proof_types.rs @@ -133,11 +133,14 @@ pub enum MemOrLutProof { #[derive(Clone, Debug)] pub struct MemSidecarProof { - /// CPU ME claims evaluated at `r_val` (Twist val-eval terminal point). + /// ME claims evaluated at `r_val` (Twist val-eval terminal point). /// - /// In shared-CPU-bus-only mode, Twist reads val-lane openings from these claims - /// (current step + optional previous step for rollover). - pub cpu_me_claims_val: Vec>, + /// Shared-bus mode only: these are CPU ME openings at `r_val` that include appended bus openings. + pub val_me_claims: Vec>, + /// CPU ME openings at `r_time` used to bind WB booleanity terminals to committed trace columns. + pub wb_me_claims: Vec>, + /// CPU ME openings at `r_time` used to bind WP quiescence terminals to committed trace columns. + pub wp_me_claims: Vec>, /// Route A Shout address pre-time proofs batched across all Shout instances in the step. pub shout_addr_pre: ShoutAddrPreProof, pub proofs: Vec, @@ -175,8 +178,14 @@ pub struct StepProof { pub fold: FoldStep, pub mem: MemSidecarProof, pub batched_time: BatchedTimeProof, - /// Optional second folding lane for Twist val-eval ME claims at `r_val`. 
- pub val_fold: Option, + /// Optional folding lane(s) for ME claims evaluated at `r_val`. + /// + /// Each proof is an independent Π_RLC→Π_DEC lane (k=1 in current usage). + pub val_fold: Vec, + /// Reserved WB folding lane(s) for staged booleanity claims. + pub wb_fold: Vec, + /// Reserved WP folding lane(s) for staged quiescence claims. + pub wp_fold: Vec, } #[derive(Clone, Debug)] @@ -216,8 +225,14 @@ impl ShardProof { let mut val = Vec::new(); for step in &self.steps { - if let Some(val_fold) = &step.val_fold { - val.extend_from_slice(&val_fold.dec_children); + for p in &step.val_fold { + val.extend_from_slice(&p.dec_children); + } + for p in &step.wb_fold { + val.extend_from_slice(&p.dec_children); + } + for p in &step.wp_fold { + val.extend_from_slice(&p.dec_children); } } diff --git a/crates/neo-fold/src/test_export.rs b/crates/neo-fold/src/test_export.rs index e7e564fe..6e03676a 100644 --- a/crates/neo-fold/src/test_export.rs +++ b/crates/neo-fold/src/test_export.rs @@ -199,8 +199,8 @@ fn build_step_ccs( matrix_b: &SparseMatrix, matrix_c: &SparseMatrix, ) -> CcsStructure { - // Pad exported R1CS to square n×n (needed for Ajtai/NC identity-first semantics). - let n = num_constraints.max(num_variables); + let n = num_constraints; + let m = num_variables; let to_triplets = |m: &SparseMatrix| -> Vec<(usize, usize, F)> { m.entries @@ -209,11 +209,11 @@ fn build_step_ccs( .collect() }; - let a = CcsMatrix::Csc(CscMat::from_triplets(to_triplets(matrix_a), n, n)); - let b = CcsMatrix::Csc(CscMat::from_triplets(to_triplets(matrix_b), n, n)); - let c = CcsMatrix::Csc(CscMat::from_triplets(to_triplets(matrix_c), n, n)); + let a = CcsMatrix::Csc(CscMat::from_triplets(to_triplets(matrix_a), n, m)); + let b = CcsMatrix::Csc(CscMat::from_triplets(to_triplets(matrix_b), n, m)); + let c = CcsMatrix::Csc(CscMat::from_triplets(to_triplets(matrix_c), n, m)); - // R1CS → CCS embedding with identity-first form: M_0 = I_n, M_1=A, M_2=B, M_3=C. 
+ // R1CS → CCS embedding: M_0=A, M_1=B, M_2=C. let f_base = SparsePoly::new( 3, vec![ @@ -227,9 +227,7 @@ fn build_step_ccs( }, // -X3 ], ); - let f = f_base.insert_var_at_front(); - - CcsStructure::new_sparse(vec![CcsMatrix::Identity { n }, a, b, c], f).expect("valid R1CS→CCS structure") + CcsStructure::new_sparse(vec![a, b, c], f_base).expect("valid R1CS→CCS structure") } fn pad_witness_to_m(mut z: Vec, m_target: usize) -> Vec { @@ -592,8 +590,14 @@ pub fn estimate_proof(proof: &crate::shard::ShardProof) -> TestExportProofEstima for step in &proof.steps { fold_lane_commitments = fold_lane_commitments.saturating_add(step.fold.ccs_out.len() + step.fold.dec_children.len() + 1); - mem_cpu_val_claim_commitments = mem_cpu_val_claim_commitments.saturating_add(step.mem.cpu_me_claims_val.len()); - if let Some(val) = &step.val_fold { + mem_cpu_val_claim_commitments = mem_cpu_val_claim_commitments.saturating_add(step.mem.val_me_claims.len()); + for val in &step.val_fold { + val_lane_commitments = val_lane_commitments.saturating_add(val.dec_children.len() + 1); + } + for val in &step.wb_fold { + val_lane_commitments = val_lane_commitments.saturating_add(val.dec_children.len() + 1); + } + for val in &step.wp_fold { val_lane_commitments = val_lane_commitments.saturating_add(val.dec_children.len() + 1); } } diff --git a/crates/neo-fold/tests/README.md b/crates/neo-fold/tests/README.md new file mode 100644 index 00000000..fe7d57de --- /dev/null +++ b/crates/neo-fold/tests/README.md @@ -0,0 +1,50 @@ +# neo-fold integration tests + +This directory is organized by suite to make ownership and intent explicit. 
+ +## Top-level test crates + +Each file at this level is a thin entrypoint that mounts one suite module: + +- `trace_shout.rs` +- `trace_twist.rs` +- `shared_bus.rs` +- `session.rs` +- `rv32m.rs` +- `integration.rs` +- `perf.rs` +- `vm.rs` +- `redteam.rs` +- `redteam_riscv.rs` +- `regression.rs` + +## Suite layout + +- `suites/trace_shout/`: shout-sidecar e2e + red-team coverage. +- `suites/trace_twist/`: twist-sidecar e2e + linkage hardening tests. +- `suites/shared_bus/`: shared CPU bus coverage and attacks. +- `suites/session/`: folding session/unit behavior. +- `suites/rv32m/`: RV32M-specific tests. +- `suites/integration/`: end-to-end/proving pipeline integration. +- `suites/perf/`: perf and scaling tests (typically `#[ignore]`). +- `suites/vm/`: VM extraction and wasm-demo tests. +- `suites/redteam/`: cross-cutting adversarial tests. +- `suites/redteam_riscv/`: RISC-V-specific adversarial tests. +- `suites/regression/`: regression tests. + +## Shared helpers + +- `common/fixtures.rs` +- `common/fib_twist_shout_vm.rs` +- `common/riscv_shout_event_table_packed.rs` +- `common/setup.rs` + +## Suggested run commands + +- `cargo test -p neo-fold --test trace_shout` +- `cargo test -p neo-fold --test trace_twist` +- `cargo test -p neo-fold --test shared_bus` +- `cargo test -p neo-fold --test session` +- `cargo test -p neo-fold --test integration` +- `cargo test -p neo-fold --test vm` +- `cargo test -p neo-fold --tests --no-run` diff --git a/crates/neo-fold/tests/common/fixtures.rs b/crates/neo-fold/tests/common/fixtures.rs index e7aa9bbf..f010322f 100644 --- a/crates/neo-fold/tests/common/fixtures.rs +++ b/crates/neo-fold/tests/common/fixtures.rs @@ -217,7 +217,7 @@ pub struct ShardFixture { fn build_twist_shout_2step_fixture_inner(seed: u64, bad_lookup_step1: bool) -> ShardFixture { // Keep CCS small but ensure it can fit the shared CPU bus tail. - // Must be square (n==m) due to identity-first ME semantics. 
+ // Keep this fixture square for simplicity; rectangular CCS is supported. let n = 32usize; let ccs = create_identity_ccs(n); let mut params = NeoParams::goldilocks_auto_r1cs_ccs(n).expect("params"); @@ -295,6 +295,7 @@ fn build_twist_shout_2step_fixture_inner(seed: u64, bad_lookup_step1: bool) -> S let lut_ell = lut_table.n_side.trailing_zeros() as usize; let mem_inst0 = neo_memory::witness::MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, @@ -306,6 +307,7 @@ fn build_twist_shout_2step_fixture_inner(seed: u64, bad_lookup_step1: bool) -> S }; let mem_wit0 = neo_memory::witness::MemWitness { mats: Vec::new() }; let lut_inst0 = neo_memory::witness::LutInstance:: { + table_id: lut_table.table_id, comms: Vec::new(), k: lut_table.k, d: lut_table.d, @@ -315,10 +317,13 @@ fn build_twist_shout_2step_fixture_inner(seed: u64, bad_lookup_step1: bool) -> S ell: lut_ell, table_spec: None, table: lut_table.content.clone(), + addr_group: None, + selector_group: None, }; let lut_wit0 = neo_memory::witness::LutWitness { mats: Vec::new() }; let mem_inst1 = neo_memory::witness::MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, @@ -330,6 +335,7 @@ fn build_twist_shout_2step_fixture_inner(seed: u64, bad_lookup_step1: bool) -> S }; let mem_wit1 = neo_memory::witness::MemWitness { mats: Vec::new() }; let lut_inst1 = neo_memory::witness::LutInstance:: { + table_id: lut_table.table_id, comms: Vec::new(), k: lut_table.k, d: lut_table.d, @@ -339,6 +345,8 @@ fn build_twist_shout_2step_fixture_inner(seed: u64, bad_lookup_step1: bool) -> S ell: lut_ell, table_spec: None, table: lut_table.content.clone(), + addr_group: None, + selector_group: None, }; let lut_wit1 = neo_memory::witness::LutWitness { mats: Vec::new() }; diff --git a/crates/neo-fold/tests/common/mod.rs b/crates/neo-fold/tests/common/mod.rs new file mode 100644 index 00000000..52c9f30f --- /dev/null +++ b/crates/neo-fold/tests/common/mod.rs @@ -0,0 +1,4 @@ +pub mod 
fib_twist_shout_vm; +pub mod fixtures; +pub mod riscv_shout_event_table_packed; +pub mod setup; diff --git a/crates/neo-fold/tests/common/riscv_shout_event_table_packed.rs b/crates/neo-fold/tests/common/riscv_shout_event_table_packed.rs new file mode 100644 index 00000000..dd79887d --- /dev/null +++ b/crates/neo-fold/tests/common/riscv_shout_event_table_packed.rs @@ -0,0 +1,714 @@ +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::exec_table::Rv32ShoutEventRow; +use neo_memory::riscv::lookups::RiscvOpcode; +use p3_field::{Field, PrimeCharacteristicRing}; + +pub fn ell_n_from_ccs_n(n: usize) -> usize { + let n_pad = n.next_power_of_two().max(2); + n_pad.trailing_zeros() as usize +} + +pub fn rv32_packed_base_d(op: RiscvOpcode) -> Result { + Ok(match op { + RiscvOpcode::And | RiscvOpcode::Andn | RiscvOpcode::Or | RiscvOpcode::Xor => 34usize, + RiscvOpcode::Add | RiscvOpcode::Sub => 3usize, + RiscvOpcode::Eq | RiscvOpcode::Neq => 35usize, + RiscvOpcode::Slt => 37usize, + RiscvOpcode::Sll | RiscvOpcode::Srl | RiscvOpcode::Sra => 38usize, + RiscvOpcode::Sltu => 35usize, + RiscvOpcode::Mul => 34usize, + RiscvOpcode::Mulh => 38usize, + RiscvOpcode::Mulhu => 34usize, + RiscvOpcode::Mulhsu => 37usize, + RiscvOpcode::Div | RiscvOpcode::Rem => 43usize, + RiscvOpcode::Divu | RiscvOpcode::Remu => 38usize, + _ => { + return Err(format!( + "event-table packed: unsupported opcode={op:?} (no packed layout)" + )); + } + }) +} + +fn mulh_hi_signed(lhs: u32, rhs: u32) -> u32 { + let a = lhs as i32 as i64; + let b = rhs as i32 as i64; + let p = a * b; + (p >> 32) as i32 as u32 +} + +fn mulhsu_hi_signed(lhs: u32, rhs: u32) -> u32 { + let a = lhs as i32 as i64; + let b = rhs as i64; + let p = a * b; + (p >> 32) as i32 as u32 +} + +fn div_signed(lhs: u32, rhs: u32) -> u32 { + let lhs_i = lhs as i32; + let rhs_i = rhs as i32; + if rhs_i == 0 { + return u32::MAX; + } + if lhs_i == i32::MIN && rhs_i == -1 { + return lhs; 
// overflow case: quotient = MIN_INT + } + (lhs_i / rhs_i) as u32 +} + +fn rem_signed(lhs: u32, rhs: u32) -> u32 { + let lhs_i = lhs as i32; + let rhs_i = rhs as i32; + if rhs_i == 0 { + return lhs; + } + if lhs_i == i32::MIN && rhs_i == -1 { + return 0; // overflow case: remainder = 0 + } + (lhs_i % rhs_i) as u32 +} + +fn divu(lhs: u32, rhs: u32) -> u32 { + if rhs == 0 { + return u32::MAX; + } + (lhs as u64 / rhs as u64) as u32 +} + +fn remu(lhs: u32, rhs: u32) -> u32 { + if rhs == 0 { + return lhs; + } + ((lhs as u64) % (rhs as u64)) as u32 +} + +pub fn rv32_expected_val(op: RiscvOpcode, lhs: u32, rhs: u32) -> Result { + Ok(match op { + RiscvOpcode::And => lhs & rhs, + RiscvOpcode::Andn => lhs & !rhs, + RiscvOpcode::Or => lhs | rhs, + RiscvOpcode::Xor => lhs ^ rhs, + RiscvOpcode::Add => lhs.wrapping_add(rhs), + RiscvOpcode::Sub => lhs.wrapping_sub(rhs), + RiscvOpcode::Eq => (lhs == rhs) as u32, + RiscvOpcode::Neq => (lhs != rhs) as u32, + RiscvOpcode::Sltu => (lhs < rhs) as u32, + RiscvOpcode::Slt => ((lhs as i32) < (rhs as i32)) as u32, + RiscvOpcode::Sll => lhs.wrapping_shl(rhs & 0x1F), + RiscvOpcode::Srl => lhs.wrapping_shr(rhs & 0x1F), + RiscvOpcode::Sra => ((lhs as i32) >> (rhs & 0x1F)) as u32, + RiscvOpcode::Mul => lhs.wrapping_mul(rhs), + RiscvOpcode::Mulhu => (((lhs as u64) * (rhs as u64)) >> 32) as u32, + RiscvOpcode::Mulh => mulh_hi_signed(lhs, rhs), + RiscvOpcode::Mulhsu => mulhsu_hi_signed(lhs, rhs), + RiscvOpcode::Div => div_signed(lhs, rhs), + RiscvOpcode::Rem => rem_signed(lhs, rhs), + RiscvOpcode::Divu => divu(lhs, rhs), + RiscvOpcode::Remu => remu(lhs, rhs), + _ => { + return Err(format!( + "event-table packed: expected value unsupported for opcode={op:?}" + )); + } + }) +} + +pub fn build_rv32_event_table_packed_cols(op: RiscvOpcode, lhs: u32, rhs: u32, val: u32) -> Result, String> { + match op { + RiscvOpcode::And | RiscvOpcode::Andn | RiscvOpcode::Or | RiscvOpcode::Xor => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != 
expected { + return Err(format!( + "packed {op:?}: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + + let mut lhs_digits = Vec::with_capacity(16); + let mut rhs_digits = Vec::with_capacity(16); + for i in 0..16usize { + lhs_digits.push(F::from_u64(((lhs >> (2 * i)) & 3) as u64)); + rhs_digits.push(F::from_u64(((rhs >> (2 * i)) & 3) as u64)); + } + let mut packed = Vec::with_capacity(34); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.extend_from_slice(&lhs_digits); + packed.extend_from_slice(&rhs_digits); + if packed.len() != 34 { + return Err("packed bitwise: digit packing length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Add => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed ADD: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + let carry = ((lhs as u64 + rhs as u64) >> 32) & 1; + Ok(vec![ + F::from_u64(lhs as u64), + F::from_u64(rhs as u64), + F::from_u64(carry), + ]) + } + RiscvOpcode::Sub => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed SUB: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + let borrow = if lhs < rhs { 1u64 } else { 0u64 }; + Ok(vec![ + F::from_u64(lhs as u64), + F::from_u64(rhs as u64), + F::from_u64(borrow), + ]) + } + RiscvOpcode::Eq | RiscvOpcode::Neq => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!("packed {op:?}: val mismatch (got {val}, expected {expected})")); + } + let borrow = if lhs < rhs { 1u64 } else { 0u64 }; + let diff = lhs.wrapping_sub(rhs); + + let mut packed = Vec::with_capacity(35); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(if borrow == 1 { F::ONE } else { F::ZERO }); + for bit in 0..32usize { + packed.push(if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 35 { + return 
Err("packed EQ/NEQ: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Sltu => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!("packed SLTU: val mismatch (got {val}, expected {expected})")); + } + let diff = lhs.wrapping_sub(rhs); + let mut packed = Vec::with_capacity(35); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(diff as u64)); + for bit in 0..32usize { + packed.push(if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 35 { + return Err("packed SLTU: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Slt => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!("packed SLT: val mismatch (got {val}, expected {expected})")); + } + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + let lhs_b = lhs ^ 0x8000_0000; + let rhs_b = rhs ^ 0x8000_0000; + let diff = lhs_b.wrapping_sub(rhs_b); + + let mut packed = Vec::with_capacity(37); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(diff as u64)); + packed.push(if lhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(if rhs_sign == 1 { F::ONE } else { F::ZERO }); + for bit in 0..32usize { + packed.push(if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 37 { + return Err("packed SLT: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Sll => { + let shamt = rhs & 0x1F; + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed SLL: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + let wide = (lhs as u64) << shamt; + let carry = (wide >> 32) as u32; + + let mut packed = Vec::with_capacity(38); + packed.push(F::from_u64(lhs as u64)); + for bit in 0..5usize { + packed.push(if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + for bit 
in 0..32usize { + packed.push(if ((carry >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 38 { + return Err("packed SLL: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Srl => { + let shamt = rhs & 0x1F; + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed SRL: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + let rem: u32 = if shamt == 0 { + 0 + } else { + let mask = (1u64 << shamt) - 1; + ((lhs as u64) & mask) as u32 + }; + + let mut packed = Vec::with_capacity(38); + packed.push(F::from_u64(lhs as u64)); + for bit in 0..5usize { + packed.push(if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + for bit in 0..32usize { + packed.push(if ((rem >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 38 { + return Err("packed SRL: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Sra => { + let shamt = rhs & 0x1F; + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed SRA: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + let sign = (lhs >> 31) & 1; + + let lhs_signed: i64 = if sign == 1 { + (lhs as i64) - (1i64 << 32) + } else { + lhs as i64 + }; + let val_signed: i64 = (val as i64) - (sign as i64) * (1i64 << 32); + let pow2: i64 = 1i64 << shamt; + let rem_i64 = lhs_signed - val_signed * pow2; + if rem_i64 < 0 { + return Err("packed SRA: negative remainder".into()); + } + let rem = rem_i64 as u64; + + let mut packed = Vec::with_capacity(38); + packed.push(F::from_u64(lhs as u64)); + for bit in 0..5usize { + packed.push(if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + packed.push(if sign == 1 { F::ONE } else { F::ZERO }); + for bit in 0..31usize { + packed.push(if ((rem >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 38 { + return Err("packed SRA: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Mul => { 
+ let wide = (lhs as u64) * (rhs as u64); + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed MUL: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + let hi = (wide >> 32) as u32; + + let mut packed = Vec::with_capacity(34); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + for bit in 0..32usize { + packed.push(if ((hi >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 34 { + return Err("packed MUL: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Mulhu => { + let wide = (lhs as u64) * (rhs as u64); + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed MULHU: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + let lo = (wide & 0xffff_ffff) as u32; + + let mut packed = Vec::with_capacity(34); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + for bit in 0..32usize { + packed.push(if ((lo >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 34 { + return Err("packed MULHU: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Mulh => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed MULH: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + + let uprod = (lhs as u64) * (rhs as u64); + let lo = (uprod & 0xffff_ffff) as u32; + let hi = (uprod >> 32) as u32; + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + + let diff = + (val as i128) - (hi as i128) + (lhs_sign as i128) * (rhs as i128) + (rhs_sign as i128) * (lhs as i128); + let two32 = 1_i128 << 32; + if diff < 0 || diff % two32 != 0 { + return Err(format!("packed MULH: invalid k (diff={diff})")); + } + let k = (diff / two32) as u32; + if k > 2 { + return Err(format!("packed MULH: k out of range (k={k})")); + } + + let mut packed = Vec::with_capacity(38); + 
packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(hi as u64)); + packed.push(if lhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(if rhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(F::from_u64(k as u64)); + for bit in 0..32usize { + packed.push(if ((lo >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 38 { + return Err("packed MULH: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Mulhsu => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed MULHSU: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + + let uprod = (lhs as u64) * (rhs as u64); + let lo = (uprod & 0xffff_ffff) as u32; + let hi = (uprod >> 32) as u32; + let lhs_sign = (lhs >> 31) & 1; + + let diff = (val as i128) - (hi as i128) + (lhs_sign as i128) * (rhs as i128); + let two32 = 1_i128 << 32; + if diff < 0 || diff % two32 != 0 { + return Err(format!("packed MULHSU: invalid borrow (diff={diff})")); + } + let borrow = (diff / two32) as u32; + if borrow > 1 { + return Err(format!("packed MULHSU: borrow out of range (borrow={borrow})")); + } + + let mut packed = Vec::with_capacity(37); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(hi as u64)); + packed.push(if lhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(if borrow == 1 { F::ONE } else { F::ZERO }); + for bit in 0..32usize { + packed.push(if ((lo >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 37 { + return Err("packed MULHSU: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Div => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed DIV: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + let lhs_abs = if lhs_sign == 0 { lhs } else { 
lhs.wrapping_neg() }; + let rhs_abs = if rhs == 0 { + 0u32 + } else if rhs_sign == 0 { + rhs + } else { + rhs.wrapping_neg() + }; + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = rhs == 0; + + let (q_abs, r_abs) = if rhs == 0 { + (0u32, 0u32) + } else { + (lhs_abs / rhs_abs, lhs_abs % rhs_abs) + }; + let q_is_zero = q_abs == 0; + let q_f = F::from_u64(q_abs as u64); + let q_inv = if q_f == F::ZERO { F::ZERO } else { q_f.inverse() }; + + let diff = if rhs == 0 { 0u32 } else { r_abs.wrapping_sub(rhs_abs) }; + + let mut packed = Vec::with_capacity(43); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(q_abs as u64)); + packed.push(F::from_u64(r_abs as u64)); + packed.push(rhs_inv); + packed.push(if rhs_is_zero { F::ONE } else { F::ZERO }); + packed.push(if lhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(if rhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(q_inv); + packed.push(if q_is_zero { F::ONE } else { F::ZERO }); + packed.push(F::from_u64(diff as u64)); + for bit in 0..32usize { + packed.push(if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 43 { + return Err("packed DIV: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Rem => { + let expected = rv32_expected_val(op, lhs, rhs)?; + if val != expected { + return Err(format!( + "packed REM: val mismatch (got {val:#x}, expected {expected:#x})" + )); + } + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + let lhs_abs = if lhs_sign == 0 { lhs } else { lhs.wrapping_neg() }; + let rhs_abs = if rhs == 0 { + 0u32 + } else if rhs_sign == 0 { + rhs + } else { + rhs.wrapping_neg() + }; + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = rhs == 0; + + let (q_abs, r_abs) = if rhs == 0 { + (0u32, 0u32) + } else { + (lhs_abs 
/ rhs_abs, lhs_abs % rhs_abs) + }; + let r_is_zero = r_abs == 0; + let r_f = F::from_u64(r_abs as u64); + let r_inv = if r_f == F::ZERO { F::ZERO } else { r_f.inverse() }; + + let diff = if rhs == 0 { 0u32 } else { r_abs.wrapping_sub(rhs_abs) }; + + let mut packed = Vec::with_capacity(43); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(q_abs as u64)); + packed.push(F::from_u64(r_abs as u64)); + packed.push(rhs_inv); + packed.push(if rhs_is_zero { F::ONE } else { F::ZERO }); + packed.push(if lhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(if rhs_sign == 1 { F::ONE } else { F::ZERO }); + packed.push(r_inv); + packed.push(if r_is_zero { F::ONE } else { F::ZERO }); + packed.push(F::from_u64(diff as u64)); + for bit in 0..32usize { + packed.push(if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 43 { + return Err("packed REM: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Divu => { + let expected_quot = rv32_expected_val(op, lhs, rhs)?; + if val != expected_quot { + return Err(format!( + "packed DIVU: val mismatch (got {val:#x}, expected {expected_quot:#x})" + )); + } + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = rhs == 0; + + let rem = if rhs == 0 { + 0u32 + } else { + let r = ((lhs as u64) % (rhs as u64)) as u32; + // Cross-check with quotient: + let r2 = (lhs as u64).wrapping_sub((rhs as u64).wrapping_mul(val as u64)) as u32; + if r2 != r { + return Err(format!( + "packed DIVU: remainder mismatch (lhs={lhs:#x}, rhs={rhs:#x}, quot={val:#x}, r2={r2:#x}, r={r:#x})" + )); + } + r + }; + let diff = rem.wrapping_sub(rhs); + + let mut packed = Vec::with_capacity(38); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(rem as u64)); + packed.push(rhs_inv); + packed.push(if rhs_is_zero { F::ONE } else { F::ZERO }); + 
packed.push(F::from_u64(diff as u64)); + for bit in 0..32usize { + packed.push(if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 38 { + return Err("packed DIVU: length mismatch".into()); + } + Ok(packed) + } + RiscvOpcode::Remu => { + let expected_rem = rv32_expected_val(op, lhs, rhs)?; + if val != expected_rem { + return Err(format!( + "packed REMU: val mismatch (got {val:#x}, expected {expected_rem:#x})" + )); + } + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = rhs == 0; + + let quot = if rhs == 0 { + 0u32 + } else { + (lhs as u64 / rhs as u64) as u32 + }; + if rhs != 0 { + let rem2 = ((lhs as u64) % (rhs as u64)) as u32; + if rem2 != val { + return Err(format!( + "packed REMU: remainder mismatch (lhs={lhs:#x}, rhs={rhs:#x}, quot={quot:#x}, rem={val:#x}, rem2={rem2:#x})" + )); + } + } + let diff = val.wrapping_sub(rhs); + + let mut packed = Vec::with_capacity(38); + packed.push(F::from_u64(lhs as u64)); + packed.push(F::from_u64(rhs as u64)); + packed.push(F::from_u64(quot as u64)); + packed.push(rhs_inv); + packed.push(if rhs_is_zero { F::ONE } else { F::ZERO }); + packed.push(F::from_u64(diff as u64)); + for bit in 0..32usize { + packed.push(if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }); + } + if packed.len() != 38 { + return Err("packed REMU: length mismatch".into()); + } + Ok(packed) + } + _ => Err(format!("event-table packed cols: unsupported opcode={op:?}")), + } +} + +pub fn build_shout_event_table_bus_z( + m: usize, + m_in: usize, + steps: usize, + ell_n: usize, + op: RiscvOpcode, + rows: &[Rv32ShoutEventRow], + x_prefix: &[F], +) -> Result, String> { + if steps == 0 { + return Err("build_shout_event_table_bus_z: steps=0".into()); + } + if rows.len() != steps { + return Err("build_shout_event_table_bus_z: rows/steps mismatch".into()); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_event_table_bus_z: 
x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + + let base_d = rv32_packed_base_d(op)?; + let ell_addr = ell_n + .checked_add(base_d) + .ok_or_else(|| "build_shout_event_table_bus_z: ell_addr overflow".to_string())?; + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + steps, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_event_table_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + for (j, row) in rows.iter().enumerate() { + z[bus.bus_cell(cols.has_lookup, j)] = F::ONE; + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(row.value as u64); + + let t_idx = m_in + .checked_add(row.row_idx) + .ok_or_else(|| "build_shout_event_table_bus_z: time index overflow".to_string())?; + let mut time_bits = vec![F::ZERO; ell_n]; + for b in 0..ell_n { + let bit = ((t_idx as u64) >> b) & 1; + time_bits[b] = if bit == 1 { F::ONE } else { F::ZERO }; + } + + let lhs = row.lhs as u32; + let rhs = row.rhs as u32; + let val = row.value as u32; + let packed_cols = build_rv32_event_table_packed_cols(op, lhs, rhs, val)?; + if packed_cols.len() != base_d { + return Err("build_shout_event_table_bus_z: packed cols length mismatch".into()); + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + let v = if idx < ell_n { + time_bits[idx] + } else { + packed_cols[idx - ell_n] + }; + z[bus.bus_cell(col_id, j)] = v; + } + } + + Ok(z) +} diff --git a/crates/neo-fold/tests/common/setup.rs b/crates/neo-fold/tests/common/setup.rs new file mode 100644 index 00000000..bad09355 --- /dev/null +++ b/crates/neo-fold/tests/common/setup.rs @@ -0,0 +1,102 @@ +#![allow(dead_code)] + +use std::sync::Arc; + +use neo_ajtai::{s_lincomb, s_mul, setup as ajtai_setup, 
AjtaiSModule, Commitment as Cmt}; +use neo_ccs::relations::CcsStructure; +use neo_ccs::sparse::{CcsMatrix, CscMat}; +use neo_ccs::Mat; +use neo_fold::shard::CommitMixers; +use neo_math::ring::{cf_inv, Rq as RqEl}; +use neo_math::{D, F}; +use neo_params::NeoParams; +use p3_field::PrimeCharacteristicRing; +use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha8Rng; + +pub type Mixers = CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt>; + +pub fn setup_ajtai_committer(params: &NeoParams, m: usize) -> AjtaiSModule { + let mut rng = ChaCha8Rng::seed_from_u64(7); + let pp = ajtai_setup(&mut rng, D, params.kappa as usize, m).expect("Ajtai setup should succeed"); + AjtaiSModule::new(Arc::new(pp)) +} + +pub fn rot_matrix_to_rq(mat: &Mat) -> RqEl { + debug_assert_eq!(mat.rows(), D); + debug_assert_eq!(mat.cols(), D); + + let mut coeffs = [F::ZERO; D]; + for i in 0..D { + coeffs[i] = mat[(i, 0)]; + } + cf_inv(coeffs) +} + +pub fn default_mixers() -> Mixers { + fn mix_rhos_commits(rhos: &[Mat], cs: &[Cmt]) -> Cmt { + assert!(!cs.is_empty(), "mix_rhos_commits: empty commitments"); + if cs.len() == 1 { + return cs[0].clone(); + } + let rq_els: Vec = rhos.iter().map(rot_matrix_to_rq).collect(); + s_lincomb(&rq_els, cs).expect("s_lincomb should succeed") + } + + fn combine_b_pows(cs: &[Cmt], b: u32) -> Cmt { + assert!(!cs.is_empty(), "combine_b_pows: empty commitments"); + let mut acc = cs[0].clone(); + let mut pow = F::from_u64(b as u64); + for c in cs.iter().skip(1) { + let rq_pow = RqEl::from_field_scalar(pow); + let term = s_mul(&rq_pow, c); + acc.add_inplace(&term); + pow *= F::from_u64(b as u64); + } + acc + } + + CommitMixers { + mix_rhos_commits, + combine_b_pows, + } +} + +/// Test-only helper: widen CCS column count by appending all-zero columns to every matrix. +/// This is used by legacy no-shared packed-bus tests when bus tails need more per-step columns +/// than the optimized RV32 trace core now commits. 
+pub fn widen_ccs_cols_for_test(ccs: &mut CcsStructure, target_m: usize) { + if target_m <= ccs.m { + return; + } + for mat in &mut ccs.matrices { + match mat { + CcsMatrix::Identity { n } => { + let nrows = *n; + let diag = nrows.min(target_m); + let mut col_ptr = Vec::with_capacity(target_m + 1); + for c in 0..=target_m { + col_ptr.push(c.min(diag)); + } + let row_idx: Vec = (0..diag).collect(); + let vals = vec![F::ONE; diag]; + *mat = CcsMatrix::Csc(CscMat { + nrows, + ncols: target_m, + col_ptr, + row_idx, + vals, + }); + } + CcsMatrix::Csc(csc) => { + if csc.ncols > target_m { + continue; + } + let nnz = *csc.col_ptr.last().unwrap_or(&0); + csc.col_ptr.resize(target_m + 1, nnz); + csc.ncols = target_m; + } + } + } + ccs.m = target_m; +} diff --git a/crates/neo-fold/tests/integration.rs b/crates/neo-fold/tests/integration.rs new file mode 100644 index 00000000..28a5e9f8 --- /dev/null +++ b/crates/neo-fold/tests/integration.rs @@ -0,0 +1,5 @@ +#[path = "common/setup.rs"] +mod common_setup; + +#[path = "suites/integration/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/perf.rs b/crates/neo-fold/tests/perf.rs new file mode 100644 index 00000000..d01a76bb --- /dev/null +++ b/crates/neo-fold/tests/perf.rs @@ -0,0 +1,5 @@ +#[path = "common/setup.rs"] +mod common_setup; + +#[path = "suites/perf/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/redteam.rs b/crates/neo-fold/tests/redteam.rs new file mode 100644 index 00000000..87fa9193 --- /dev/null +++ b/crates/neo-fold/tests/redteam.rs @@ -0,0 +1,2 @@ +#[path = "suites/redteam/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/redteam_riscv.rs b/crates/neo-fold/tests/redteam_riscv.rs new file mode 100644 index 00000000..ae96eed9 --- /dev/null +++ b/crates/neo-fold/tests/redteam_riscv.rs @@ -0,0 +1,2 @@ +#[path = "suites/redteam_riscv/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/regression.rs b/crates/neo-fold/tests/regression.rs new file mode 100644 index 00000000..6c5b0f40 --- /dev/null 
+++ b/crates/neo-fold/tests/regression.rs @@ -0,0 +1,2 @@ +#[path = "suites/regression/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/riscv_rv32m_mul_divu_remu_prove_verify.rs b/crates/neo-fold/tests/riscv_rv32m_mul_divu_remu_prove_verify.rs deleted file mode 100644 index 5cc22284..00000000 --- a/crates/neo-fold/tests/riscv_rv32m_mul_divu_remu_prove_verify.rs +++ /dev/null @@ -1,161 +0,0 @@ -use neo_fold::riscv_shard::Rv32B1; -use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; -use p3_field::PrimeCharacteristicRing; -use p3_goldilocks::Goldilocks as F; - -#[test] -fn rv32_b1_prove_verify_mul_divu_remu() { - let program = vec![ - // x1 = 7 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 7, - }, - // x2 = 13 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 13, - }, - // x3 = x1 * x2 = 91 - RiscvInstruction::RAlu { - op: RiscvOpcode::Mul, - rd: 3, - rs1: 1, - rs2: 2, - }, - // x4 = x3 / x1 = 13 - RiscvInstruction::RAlu { - op: RiscvOpcode::Divu, - rd: 4, - rs1: 3, - rs2: 1, - }, - // x5 = x3 % x1 = 0 - RiscvInstruction::RAlu { - op: RiscvOpcode::Remu, - rd: 5, - rs1: 3, - rs2: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - - let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) - .chunk_size(1) - .max_steps(program.len()) - .prove() - .expect("prove"); - run.verify().expect("verify"); - - let boundary = run.final_boundary_state().expect("final boundary state"); - assert_eq!(boundary.regs_final[3], F::from_u64(91)); - assert_eq!(boundary.regs_final[4], F::from_u64(13)); - assert_eq!(boundary.regs_final[5], F::from_u64(0)); -} - -#[test] -fn rv32_b1_prove_verify_divu_remu_by_zero() { - let dividend = 1234u64; - let program = vec![ - // x1 = dividend - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: dividend as i32, - }, - // x2 = 0 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - 
rs1: 0, - imm: 0, - }, - // x3 = x1 / x2 (DIVU by zero => 0xffffffff) - RiscvInstruction::RAlu { - op: RiscvOpcode::Divu, - rd: 3, - rs1: 1, - rs2: 2, - }, - // x4 = x1 % x2 (REMU by zero => dividend) - RiscvInstruction::RAlu { - op: RiscvOpcode::Remu, - rd: 4, - rs1: 1, - rs2: 2, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - - let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) - .chunk_size(1) - .max_steps(program.len()) - .prove() - .expect("prove"); - run.verify().expect("verify"); - - let boundary = run.final_boundary_state().expect("final boundary state"); - assert_eq!(boundary.regs_final[1], F::from_u64(dividend)); - assert_eq!(boundary.regs_final[3], F::from_u64(u32::MAX as u64)); - assert_eq!(boundary.regs_final[4], F::from_u64(dividend)); -} - -#[test] -fn rv32_b1_prove_verify_div_rem_signed_auto_minimal_includes_sltu() { - // This test specifically exercises the RV32M signed DIV/REM path under `Rv32B1`'s - // default `.shout_auto_minimal()` inference. The step circuit requires SLTU to be - // provisioned so it can do the remainder bound check when divisor != 0. 
- let program = vec![ - // x1 = -7 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: -7, - }, - // x2 = 3 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 3, - }, - // x3 = x1 / x2 = -2 - RiscvInstruction::RAlu { - op: RiscvOpcode::Div, - rd: 3, - rs1: 1, - rs2: 2, - }, - // x4 = x1 % x2 = -1 - RiscvInstruction::RAlu { - op: RiscvOpcode::Rem, - rd: 4, - rs1: 1, - rs2: 2, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - - let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) - .chunk_size(1) - .max_steps(program.len()) - .prove() - .expect("prove"); - run.verify().expect("verify"); - - let boundary = run.final_boundary_state().expect("final boundary state"); - assert_eq!(boundary.regs_final[3], F::from_u64(0xffff_fffe)); // -2 - assert_eq!(boundary.regs_final[4], F::from_u64(0xffff_ffff)); // -1 -} diff --git a/crates/neo-fold/tests/rv32m.rs b/crates/neo-fold/tests/rv32m.rs new file mode 100644 index 00000000..f5a1b254 --- /dev/null +++ b/crates/neo-fold/tests/rv32m.rs @@ -0,0 +1,2 @@ +#[path = "suites/rv32m/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/session.rs b/crates/neo-fold/tests/session.rs new file mode 100644 index 00000000..71867c47 --- /dev/null +++ b/crates/neo-fold/tests/session.rs @@ -0,0 +1,2 @@ +#[path = "suites/session/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/shared_bus.rs b/crates/neo-fold/tests/shared_bus.rs new file mode 100644 index 00000000..bf034576 --- /dev/null +++ b/crates/neo-fold/tests/shared_bus.rs @@ -0,0 +1,5 @@ +#[path = "common/setup.rs"] +mod common_setup; + +#[path = "suites/shared_bus/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/full_folding_integration.rs b/crates/neo-fold/tests/suites/integration/full_folding_integration.rs similarity index 96% rename from crates/neo-fold/tests/full_folding_integration.rs rename to crates/neo-fold/tests/suites/integration/full_folding_integration.rs 
index 42437409..ab07d287 100644 --- a/crates/neo-fold/tests/full_folding_integration.rs +++ b/crates/neo-fold/tests/suites/integration/full_folding_integration.rs @@ -12,7 +12,7 @@ use neo_fold::finalize::{FinalizeReport, ObligationFinalizer}; use neo_fold::shard::CommitMixers; use neo_fold::shard::{fold_shard_prove, fold_shard_verify, fold_shard_verify_and_finalize, ShardObligations}; use neo_fold::PiCcsError; -use neo_math::{D, K}; +use neo_math::K; use neo_memory::plain::{PlainLutTrace, PlainMemLayout, PlainMemTrace}; use neo_memory::witness::{StepInstanceBundle, StepWitnessBundle}; use neo_memory::MemInit; @@ -312,16 +312,7 @@ fn write_bus_for_chunk( } fn default_mixers() -> Mixers { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn build_single_chunk_inputs() -> ( @@ -399,6 +390,7 @@ fn build_single_chunk_inputs() -> ( // Shared-bus mode: instances are metadata-only; access rows live in the CPU witness. 
let mem_inst = neo_memory::witness::MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, @@ -410,6 +402,7 @@ fn build_single_chunk_inputs() -> ( }; let mem_wit = neo_memory::witness::MemWitness { mats: Vec::new() }; let lut_inst = neo_memory::witness::LutInstance:: { + table_id: lut_table.table_id, comms: Vec::new(), k: lut_table.k, d: lut_table.d, @@ -419,6 +412,8 @@ fn build_single_chunk_inputs() -> ( ell: lut_table.n_side.trailing_zeros() as usize, table_spec: None, table: lut_table.content.clone(), + addr_group: None, + selector_group: None, }; let lut_wit = neo_memory::witness::LutWitness { mats: Vec::new() }; @@ -464,7 +459,7 @@ fn full_folding_integration_single_chunk() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -479,7 +474,7 @@ fn full_folding_integration_single_chunk() { let mut tr_verify = Poseidon2Transcript::new(b"full-fold"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _outputs = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -492,7 +487,7 @@ fn full_folding_integration_single_chunk() { // Print a short summary so it's clear what was enforced. 
let step0 = &proof.steps[0]; - let mem_me_val = step0.mem.cpu_me_claims_val.len(); + let mem_me_val = step0.mem.val_me_claims.len(); let ccs_me = step0.fold.ccs_out.len(); let total_me = ccs_me; let children = step0.fold.dec_children.len(); @@ -566,6 +561,7 @@ fn full_folding_integration_multi_step_chunk() { }; let mem_inst = neo_memory::witness::MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, @@ -577,6 +573,7 @@ fn full_folding_integration_multi_step_chunk() { }; let mem_wit = neo_memory::witness::MemWitness { mats: Vec::new() }; let lut_inst = neo_memory::witness::LutInstance:: { + table_id: lut_table.table_id, comms: Vec::new(), k: lut_table.k, d: lut_table.d, @@ -586,6 +583,8 @@ fn full_folding_integration_multi_step_chunk() { ell: lut_table.n_side.trailing_zeros() as usize, table_spec: None, table: lut_table.content.clone(), + addr_group: None, + selector_group: None, }; let lut_wit = neo_memory::witness::LutWitness { mats: Vec::new() }; @@ -622,7 +621,7 @@ fn full_folding_integration_multi_step_chunk() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-multi-step-chunk"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -637,7 +636,7 @@ fn full_folding_integration_multi_step_chunk() { let mut tr_verify = Poseidon2Transcript::new(b"full-fold-multi-step-chunk"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -655,7 +654,7 @@ fn tamper_batched_claimed_sum_fails() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-tamper-claim"); let mut proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -673,7 +672,7 @@ fn tamper_batched_claimed_sum_fails() { let mut tr_verify = Poseidon2Transcript::new(b"full-fold-tamper-claim"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; 
let result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -692,7 +691,7 @@ fn tamper_me_opening_fails() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-tamper-me"); let mut proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -714,7 +713,7 @@ fn tamper_me_opening_fails() { let mut tr_verify = Poseidon2Transcript::new(b"full-fold-tamper-me"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -733,7 +732,7 @@ fn tamper_shout_addr_pre_round_poly_fails() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-tamper-shout-addr-pre"); let mut proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -758,7 +757,7 @@ fn tamper_shout_addr_pre_round_poly_fails() { let mut tr_verify = Poseidon2Transcript::new(b"full-fold-tamper-shout-addr-pre"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -782,7 +781,7 @@ fn tamper_twist_val_eval_round_poly_fails() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-tamper-twist-val-eval-rounds"); let mut proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -807,7 +806,7 @@ fn tamper_twist_val_eval_round_poly_fails() { let mut tr_verify = Poseidon2Transcript::new(b"full-fold-tamper-twist-val-eval-rounds"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -829,7 +828,7 @@ fn missing_val_fold_fails() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-missing-val-fold"); let mut proof = fold_shard_prove( - 
FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -842,15 +841,15 @@ fn missing_val_fold_fails() { .expect("prove should succeed"); assert!( - proof.steps[0].val_fold.is_some(), + !proof.steps[0].val_fold.is_empty(), "fixture should produce val_fold when Twist is present" ); - proof.steps[0].val_fold = None; + proof.steps[0].val_fold.clear(); let mut tr_verify = Poseidon2Transcript::new(b"full-fold-missing-val-fold"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -887,7 +886,7 @@ fn verify_and_finalize_receives_val_lane() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-finalizer"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -903,7 +902,7 @@ fn verify_and_finalize_receives_val_lane() { let steps_public = [StepInstanceBundle::from(&step_bundle)]; let mut fin = RequireValLane; fold_shard_verify_and_finalize( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -935,7 +934,7 @@ fn main_only_finalizer_is_rejected_when_val_lane_present() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-finalizer-main-only"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -951,7 +950,7 @@ fn main_only_finalizer_is_rejected_when_val_lane_present() { let steps_public = [StepInstanceBundle::from(&step_bundle)]; let mut fin = MainOnly; let res = fold_shard_verify_and_finalize( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -992,7 +991,7 @@ fn wrong_shout_lookup_value_witness_fails() { let mut tr_prove = Poseidon2Transcript::new(b"full-fold-wrong-shout-bus"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -1007,7 +1006,7 @@ fn 
wrong_shout_lookup_value_witness_fails() { let mut tr_verify = Poseidon2Transcript::new(b"full-fold-wrong-shout-bus"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let res = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, diff --git a/crates/neo-fold/tests/suites/integration/mod.rs b/crates/neo-fold/tests/suites/integration/mod.rs new file mode 100644 index 00000000..8fbe8a19 --- /dev/null +++ b/crates/neo-fold/tests/suites/integration/mod.rs @@ -0,0 +1,9 @@ +mod full_folding_integration; +mod output_binding; +mod rectangular_ccs_e2e; +mod riscv_proof_integration; +mod riscv_trace_wiring_ccs_e2e; +mod riscv_trace_wiring_mode_e2e; +mod riscv_trace_wiring_runner_e2e; +mod shard_continuation_extend_and_fold; +mod streaming_dec_equivalence; diff --git a/crates/neo-fold/tests/suites/integration/output_binding.rs b/crates/neo-fold/tests/suites/integration/output_binding.rs new file mode 100644 index 00000000..4e51f5ab --- /dev/null +++ b/crates/neo-fold/tests/suites/integration/output_binding.rs @@ -0,0 +1,4 @@ +#[path = "output_binding_e2e.rs"] +mod output_binding_e2e; +#[path = "output_binding_tests.rs"] +mod output_binding_tests; diff --git a/crates/neo-fold/tests/output_binding_e2e.rs b/crates/neo-fold/tests/suites/integration/output_binding_e2e.rs similarity index 93% rename from crates/neo-fold/tests/output_binding_e2e.rs rename to crates/neo-fold/tests/suites/integration/output_binding_e2e.rs index 3ff788f9..00ec81e8 100644 --- a/crates/neo-fold/tests/output_binding_e2e.rs +++ b/crates/neo-fold/tests/suites/integration/output_binding_e2e.rs @@ -11,7 +11,7 @@ use neo_fold::output_binding::OutputBindingConfig; use neo_fold::pi_ccs::FoldingMode; use neo_fold::shard::{fold_shard_prove_with_output_binding, fold_shard_verify_with_output_binding, CommitMixers}; use neo_fold::PiCcsError; -use neo_math::{D, F, K}; +use neo_math::{F, K}; use neo_memory::cpu::build_bus_layout_for_instances; use 
neo_memory::cpu::constraints::{extend_ccs_with_shared_cpu_bus_constraints, TwistCpuBinding}; use neo_memory::output_check::ProgramIO; @@ -21,6 +21,8 @@ use neo_params::NeoParams; use neo_transcript::{Poseidon2Transcript, Transcript}; use p3_field::PrimeCharacteristicRing; +type Mixers = CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt>; + #[derive(Clone, Copy, Default)] struct DummyCommit; @@ -41,17 +43,8 @@ impl SModuleHomomorphism for DummyCommit { } } -fn default_mixers() -> CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt> { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } +fn default_mixers() -> Mixers { + crate::common_setup::default_mixers() } fn empty_identity_first_r1cs_ccs(n: usize) -> CcsStructure { @@ -105,6 +98,7 @@ fn output_binding_e2e_wrong_claim_fails() -> Result<(), PiCcsError> { let const_one_col = 0usize; let mem_inst = MemInstance:: { + mem_id: 0, comms: Vec::new(), k: 4, d: 1, @@ -205,7 +199,7 @@ fn output_binding_e2e_wrong_claim_fails() -> Result<(), PiCcsError> { let mut tr_prove = Poseidon2Transcript::new(b"output-binding-e2e"); let proof = fold_shard_prove_with_output_binding( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -220,7 +214,7 @@ fn output_binding_e2e_wrong_claim_fails() -> Result<(), PiCcsError> { let mut tr_verify_ok = Poseidon2Transcript::new(b"output-binding-e2e"); let _outputs_ok = fold_shard_verify_with_output_binding( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify_ok, ¶ms, &ccs, @@ -233,7 +227,7 @@ fn output_binding_e2e_wrong_claim_fails() -> Result<(), PiCcsError> { let mut tr_verify_bad = Poseidon2Transcript::new(b"output-binding-e2e"); let res = fold_shard_verify_with_output_binding( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify_bad, ¶ms, &ccs, diff --git 
a/crates/neo-fold/tests/output_binding_tests.rs b/crates/neo-fold/tests/suites/integration/output_binding_tests.rs similarity index 100% rename from crates/neo-fold/tests/output_binding_tests.rs rename to crates/neo-fold/tests/suites/integration/output_binding_tests.rs diff --git a/crates/neo-fold/tests/rectangular_ccs_e2e.rs b/crates/neo-fold/tests/suites/integration/rectangular_ccs_e2e.rs similarity index 100% rename from crates/neo-fold/tests/rectangular_ccs_e2e.rs rename to crates/neo-fold/tests/suites/integration/rectangular_ccs_e2e.rs diff --git a/crates/neo-fold/tests/riscv_proof_integration.rs b/crates/neo-fold/tests/suites/integration/riscv_proof_integration.rs similarity index 100% rename from crates/neo-fold/tests/riscv_proof_integration.rs rename to crates/neo-fold/tests/suites/integration/riscv_proof_integration.rs diff --git a/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_ccs_e2e.rs b/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_ccs_e2e.rs new file mode 100644 index 00000000..a58cd977 --- /dev/null +++ b/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_ccs_e2e.rs @@ -0,0 +1,51 @@ +use neo_ajtai::AjtaiSModule; +use neo_ccs::relations::check_ccs_rowwise_zero; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::session::FoldingSession; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_vm_trace::trace_program; + +#[test] +fn riscv_trace_wiring_ccs_single_step_prove_verify() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = 
decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + check_ccs_rowwise_zero(&ccs, &x, &w).expect("trace CCS rowwise zero"); + + let mut session = FoldingSession::::new_ajtai_seeded(FoldingMode::Optimized, &ccs, [9u8; 32]) + .expect("new_ajtai_seeded"); + session.add_step_io(&ccs, &x, &w).expect("add_step_io"); + session + .prove_and_verify_collected(&ccs) + .expect("prove_and_verify_collected"); +} diff --git a/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_mode_e2e.rs b/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_mode_e2e.rs new file mode 100644 index 00000000..c7dc73e7 --- /dev/null +++ b/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_mode_e2e.rs @@ -0,0 +1,187 @@ +#![allow(non_snake_case)] + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn trace_mode_program_bytes() -> Vec { + // Program: ADDI x1, x0, 1; ADDI x2, x1, 2; HALT + let program = vec![ + RiscvInstruction::IAlu { 
+ op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 2, + }, + RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +#[test] +fn rv32_trace_wiring_mode_prove_verify() { + let program_bytes = trace_mode_program_bytes(); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .reg_init_u32(/*reg=*/ 3, /*value=*/ 9) + .ram_init_u32(/*addr=*/ 16, /*value=*/ 7) + .min_trace_len(8) + .prove() + .expect("trace wiring prove"); + + run.verify().expect("trace wiring verify"); + + assert_eq!(run.fold_count(), 1, "trace-wiring mode should produce one folding step"); + assert_eq!(run.trace_len(), 3, "active trace length mismatch"); + assert_eq!( + run.exec_table().rows.len(), + 8, + "trace_min_len should set padded trace length" + ); + assert_eq!(run.layout().t, 8, "layout t should match padded trace length"); +} + +#[test] +fn rv32_trace_wiring_mode_does_not_force_pow2_padding() { + let program_bytes = trace_mode_program_bytes(); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .min_trace_len(1) + .prove() + .expect("trace wiring prove"); + + run.verify().expect("trace wiring verify"); + + assert_eq!(run.trace_len(), 3, "active trace length mismatch"); + assert_eq!( + run.exec_table().rows.len(), + 3, + "trace-wiring mode should keep unpadded trace length when min bound is smaller" + ); + assert_eq!(run.layout().t, 3, "layout t should match unpadded trace length"); +} + +#[test] +fn rv32_trace_wiring_mode_ram_output_binding_prove_verify() { + // Program: ADDI x1, x0, 7; SW x1, 16(x0); HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, + }, + RiscvInstruction::Store { + op: neo_memory::riscv::lookups::RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 16, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = 
Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .output_claim(/*addr=*/ 16, /*value=*/ neo_math::F::from_u64(7)) + .prove() + .expect("trace wiring prove with RAM output binding"); + + run.verify() + .expect("trace wiring verify with RAM output binding"); +} + +#[test] +fn rv32_trace_wiring_mode_reg_output_binding_prove_verify() { + // Program: ADDI x2, x0, 3; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .reg_output_claim(/*reg=*/ 2, /*value=*/ neo_math::F::from_u64(3)) + .prove() + .expect("trace wiring prove with REG output binding"); + + run.verify() + .expect("trace wiring verify with REG output binding"); +} + +#[test] +fn rv32_trace_wiring_mode_wrong_reg_output_claim_fails_verify() { + // Program: ADDI x2, x0, 3; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .reg_output_claim(/*reg=*/ 2, /*value=*/ neo_math::F::from_u64(4)) + .prove() + .expect("trace wiring prove with wrong REG claim still produces a proof"); + + let err = run + .verify() + .expect_err("trace wiring verify should fail for wrong REG output claim"); + let msg = format!("{err}"); + assert!(msg.contains("output sumcheck failed"), "unexpected verify error: {msg}"); +} + +#[test] +fn rv32_trace_wiring_mode_allows_without_insecure_ack() { + let program_bytes = trace_mode_program_bytes(); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace-wiring mode should not require insecure benchmark-only ack"); + run.verify() + .expect("trace-wiring proof should verify without insecure 
benchmark-only ack"); +} + +#[test] +fn rv32_trace_wiring_mode_chunked_ivc() { + let program_bytes = trace_mode_program_bytes(); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(2) + .prove() + .expect("trace wiring prove with chunked ivc"); + + run.verify().expect("trace wiring verify with chunked ivc"); + assert_eq!(run.fold_count(), 2, "expected two fold steps with trace_chunk_rows=2"); +} + +#[test] +fn rv32_trace_shout_override_must_superset_inferred_set() { + let program_bytes = trace_mode_program_bytes(); + let err = match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .shout_ops([RiscvOpcode::Xor]) + .prove() + { + Ok(_) => panic!("shout override that misses required tables must fail"), + Err(e) => e, + }; + let msg = err.to_string(); + assert!( + msg.contains("superset") && msg.contains("Add"), + "unexpected error message: {msg}" + ); +} diff --git a/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_runner_e2e.rs b/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_runner_e2e.rs new file mode 100644 index 00000000..2ab92a2f --- /dev/null +++ b/crates/neo-fold/tests/suites/integration/riscv_trace_wiring_runner_e2e.rs @@ -0,0 +1,841 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::K; +use neo_memory::riscv::ccs::TraceShoutBusSpec; +use neo_memory::riscv::lookups::{ + encode_program, BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode, RiscvShoutTables, PROG_ID, RAM_ID, + REG_ID, +}; +use neo_memory::riscv::trace::{ + rv32_decode_lookup_backed_cols, rv32_is_decode_lookup_table_id, rv32_width_lookup_backed_cols, + Rv32DecodeSidecarLayout, Rv32WidthSidecarLayout, +}; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn rv32_trace_wiring_runner_prove_verify() { + // Program: ADDI x1, x0, 1; ADDI x2, x1, 2; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: 
RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .min_trace_len(1) + .prove() + .expect("trace wiring prove"); + + run.verify().expect("trace wiring verify"); + + assert_eq!(run.fold_count(), 1, "trace runner should produce one folding step"); + assert_eq!(run.trace_len(), 3, "active trace length mismatch"); + assert_eq!( + run.exec_table().rows.len(), + 3, + "exec table should not be padded to next power-of-two" + ); + assert_eq!( + run.layout().t, + run.exec_table().rows.len(), + "layout.t should match exec rows" + ); + + let steps_public = run.steps_public(); + assert_eq!(steps_public.len(), 1, "trace runner should expose one step instance"); + let mut mem_ids: Vec = steps_public[0] + .mem_insts + .iter() + .map(|inst| inst.mem_id) + .collect(); + mem_ids.sort_unstable(); + let mut expected_mem_ids = vec![PROG_ID.0, RAM_ID.0, REG_ID.0]; + expected_mem_ids.retain(|&id| id != RAM_ID.0); + expected_mem_ids.sort_unstable(); + assert_eq!( + mem_ids, expected_mem_ids, + "trace runner should default to used-sidecar instantiation (no RAM sidecar when unused)" + ); + assert_eq!( + run.used_memory_ids(), + expected_mem_ids.as_slice(), + "run artifact should record auto-derived S_memory" + ); + let add_table_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0; + let used_lookup_ids = run.used_shout_table_ids(); + assert!( + used_lookup_ids.contains(&add_table_id), + "run artifact should include opcode-backed S_lookup tables" + ); + let decode_lookup_count = used_lookup_ids + .iter() + .copied() + .filter(|table_id| rv32_is_decode_lookup_table_id(*table_id)) + .count(); + assert_eq!( + decode_lookup_count, + rv32_decode_lookup_backed_cols(&Rv32DecodeSidecarLayout::new()).len(), + "run artifact should include decode lookup families in S_lookup" + ); +} + +#[test] +fn 
rv32_trace_wiring_runner_reg_output_binding_prove_verify() { + // Program: ADDI x2, x0, 3; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .reg_output_claim(/*reg=*/ 2, /*expected=*/ neo_math::F::from_u64(3)) + .prove() + .expect("trace wiring prove with reg output binding"); + + run.verify() + .expect("trace wiring verify with reg output binding"); +} + +#[test] +fn rv32_trace_wiring_runner_allows_without_insecure_ack() { + let program = vec![RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring should no longer require insecure benchmark-only ack"); + run.verify() + .expect("trace wiring proof should verify without insecure benchmark-only ack"); +} + +#[test] +fn rv32_trace_wiring_runner_prove_verify_without_insecure_ack() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .min_trace_len(1) + .prove() + .expect("trace wiring should prove without insecure benchmark-only ack"); + + run.verify() + .expect("trace wiring proof should verify without insecure benchmark-only ack"); +} + +#[test] +fn rv32_trace_wiring_runner_shared_bus_default_and_legacy_fallback_differ() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let run_shared = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .min_trace_len(1) + .prove() + .expect("trace wiring prove"); 
+ + let legacy_err = match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .shared_cpu_bus(false) + .min_trace_len(1) + .prove() + { + Ok(_) => panic!("legacy no-shared fallback must be rejected"), + Err(e) => e, + }; + + let msg = legacy_err.to_string(); + assert!( + msg.contains("no-shared fallback is removed"), + "unexpected no-shared rejection error: {msg}" + ); + assert_eq!( + run_shared.ccs_num_variables(), + run_shared.layout().m, + "shared-bus trace layout width must match CCS width" + ); +} + +#[test] +fn rv32_trace_wiring_runner_shout_override_must_superset_inferred_set() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let err = match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .shout_ops([RiscvOpcode::Xor]) + .prove() + { + Ok(_) => panic!("shout override that misses required tables must fail"), + Err(e) => e, + }; + let msg = err.to_string(); + assert!( + msg.contains("superset") && msg.contains("Add"), + "unexpected error message: {msg}" + ); +} + +#[test] +fn rv32_trace_wiring_runner_rejects_extra_shout_spec_without_table_spec() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let err = match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .extra_shout_bus_specs([TraceShoutBusSpec { + table_id: 1000, + ell_addr: 13, + n_vals: 1usize, + }]) + .prove() + { + Ok(_) => panic!("extra shout geometry without table spec must fail"), + Err(e) => e, + }; + let msg = err.to_string(); + assert!( + msg.contains("extra_shout_bus_specs includes table_id=1000 without a table spec"), + "unexpected error message: {msg}" + ); +} + +#[test] +fn rv32_trace_wiring_runner_accepts_extra_shout_spec_with_matching_table_spec() { + let program 
= vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .extra_lut_table_spec(1000, neo_memory::witness::LutTableSpec::IdentityU32) + .extra_shout_bus_specs([TraceShoutBusSpec { + table_id: 1000, + ell_addr: 32, + n_vals: 1usize, + }]) + .prove() + .expect("trace wiring prove with extra table/spec"); + run.verify() + .expect("trace wiring verify with extra table/spec"); + + assert!( + run.used_shout_table_ids().contains(&1000), + "run should record extra table_id in used shout set" + ); +} + +#[test] +fn rv32_trace_wiring_runner_rejects_extra_table_spec_colliding_with_opcode_table() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let add_table_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0; + let err = match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .extra_lut_table_spec(add_table_id, neo_memory::witness::LutTableSpec::IdentityU32) + .prove() + { + Ok(_) => panic!("extra_lut_table_spec collision with inferred opcode table must fail"), + Err(e) => e, + }; + let msg = err.to_string(); + assert!( + msg.contains("extra_lut_table_spec collides with existing table_id"), + "unexpected error message: {msg}" + ); +} + +#[test] +fn rv32_trace_wiring_runner_rejects_max_steps_above_trace_cap() { + let program = vec![RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let err = match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .max_steps((1usize << 20) + 1) + .prove() + { + Ok(_) => panic!("max_steps above trace cap must be rejected"), + Err(e) => e, + }; + + let msg = err.to_string(); + assert!( + msg.contains("max_steps=") && msg.contains("trace-mode hard 
cap"), + "unexpected error message: {msg}" + ); +} + +#[test] +fn rv32_trace_wiring_runner_rejects_min_trace_len_above_trace_cap() { + let program = vec![RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let err = match Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .min_trace_len((1usize << 20) + 1) + .prove() + { + Ok(_) => panic!("min_trace_len above trace cap must be rejected"), + Err(e) => e, + }; + + let msg = err.to_string(); + assert!( + msg.contains("min_trace_len=") && msg.contains("trace-mode hard cap"), + "unexpected error message: {msg}" + ); +} + +#[test] +fn rv32_trace_wiring_runner_chunked_ivc_step_linking() { + // Program: ADDI x1, x0, 1; ADDI x2, x1, 2; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(2) + .prove() + .expect("trace wiring prove with chunked ivc"); + + run.verify().expect("trace wiring verify with chunked ivc"); + + assert_eq!( + run.fold_count(), + 2, + "chunk_rows=2 over 3 rows should produce two fold steps" + ); + let steps = run.steps_public(); + assert_eq!(steps.len(), 2, "expected two public steps"); + + let layout = run.layout(); + let prev = &steps[0].mcs_inst.x; + let cur = &steps[1].mcs_inst.x; + assert_eq!( + prev[layout.pc_final], cur[layout.pc0], + "trace step linking must enforce pc_final -> pc0 across steps" + ); +} + +#[test] +fn rv32_trace_wiring_runner_chunked_ivc_batches_no_shared_val_lanes_per_mem() { + // Program: ADDI x1, x0, 1; ADDI x2, x1, 2; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 2, + }, + 
RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(2) + .prove() + .expect("trace wiring prove with chunked ivc"); + run.verify().expect("trace wiring verify with chunked ivc"); + + let steps_public = run.steps_public(); + let shard_proof = run.proof(); + assert_eq!(steps_public.len(), 2, "expected two public steps"); + assert_eq!(shard_proof.steps.len(), 2, "expected two proof steps"); + + // Step 0 (shared-bus): one current CPU val claim. + let proof_step0 = &shard_proof.steps[0]; + assert_eq!( + proof_step0.mem.val_me_claims.len(), + 1, + "step0(shared) must emit one current CPU val claim" + ); + assert_eq!( + proof_step0.val_fold.len(), + 1, + "step0(shared) must emit one val-fold proof" + ); + + // Step 1 (shared-bus): val claims are [current_cpu, previous_cpu], each with its own fold proof. + let proof_step1 = &shard_proof.steps[1]; + assert_eq!( + proof_step1.mem.val_me_claims.len(), + 2, + "step1(shared) must emit current+previous CPU val claims" + ); + assert_eq!( + proof_step1.val_fold.len(), + 2, + "step1(shared) must emit one val-fold proof per claim" + ); +} + +#[test] +fn rv32_trace_wiring_runner_wb_wp_folds_are_emitted_and_required() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); + + let proof = run.proof().clone(); + assert_eq!(proof.steps.len(), 1, "expected one step proof"); + assert!( + !proof.steps[0].mem.wb_me_claims.is_empty(), + "expected WB ME claims for RV32 trace route-A" + ); + assert!( + !proof.steps[0].mem.wp_me_claims.is_empty(), + "expected WP ME claims for RV32 trace 
route-A" + ); + assert!( + !proof.steps[0].wb_fold.is_empty(), + "expected wb_fold proofs for RV32 trace route-A" + ); + assert!( + !proof.steps[0].wp_fold.is_empty(), + "expected wp_fold proofs for RV32 trace route-A" + ); + + let mut proof_missing_wb = proof.clone(); + proof_missing_wb.steps[0].wb_fold.clear(); + assert!( + run.verify_proof(&proof_missing_wb).is_err(), + "missing wb_fold must fail verification" + ); + + let mut proof_missing_wp = proof.clone(); + proof_missing_wp.steps[0].wp_fold.clear(); + assert!( + run.verify_proof(&proof_missing_wp).is_err(), + "missing wp_fold must fail verification" + ); +} + +#[test] +fn rv32_trace_wiring_runner_decode_openings_are_embedded_in_wp_and_required() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); + + let proof = run.proof().clone(); + assert_eq!(proof.steps.len(), 1, "expected one step proof"); + assert_eq!(proof.steps[0].mem.wp_me_claims.len(), 1, "expected one WP ME claim"); + let mut proof_missing_decode_me = proof.clone(); + let decode_layout = Rv32DecodeSidecarLayout::new(); + let decode_open_cols = rv32_decode_lookup_backed_cols(&decode_layout); + let me = &mut proof_missing_decode_me.steps[0].mem.wp_me_claims[0]; + let decode_start = me + .y_scalars + .len() + .checked_sub(decode_open_cols.len()) + .expect("decode openings must be appended to WP ME tail"); + let decode_idx = decode_open_cols + .iter() + .position(|&c| c == decode_layout.op_alu_imm) + .expect("decode opening column must be present"); + me.y_scalars[decode_start + decode_idx] += K::ONE; + assert!( + run.verify_proof(&proof_missing_decode_me).is_err(), + "tampered decode lookup opening embedded in WP 
ME must fail verification" + ); +} + +#[test] +fn rv32_trace_wiring_runner_width_openings_on_wp_are_required() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); + + let proof = run.proof().clone(); + assert_eq!(proof.steps.len(), 1, "expected one step proof"); + assert_eq!(proof.steps[0].mem.wp_me_claims.len(), 1, "expected one WP ME claim"); + + let mut proof_tampered_width_open = proof.clone(); + let width_layout = Rv32WidthSidecarLayout::new(); + let width_open_cols = rv32_width_lookup_backed_cols(&width_layout); + let wp_me = &mut proof_tampered_width_open.steps[0].mem.wp_me_claims[0]; + let width_open_start = wp_me + .y_scalars + .len() + .checked_sub(width_open_cols.len()) + .expect("width openings must be appended to WP ME tail"); + let width_idx = width_open_cols + .iter() + .position(|&c| c == width_layout.rs2_low_bit[0]) + .expect("width opening column must be present"); + wp_me.y_scalars[width_open_start + width_idx] += K::ONE; + assert!( + run.verify_proof(&proof_tampered_width_open).is_err(), + "tampered width lookup opening embedded in WP ME must fail verification" + ); +} + +#[test] +fn rv32_trace_wiring_runner_control_claims_are_emitted_and_required() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); + + let proof = run.proof().clone(); + assert_eq!(proof.steps.len(), 1, 
"expected one step proof"); + + let labels = &proof.steps[0].batched_time.labels; + let find_w4 = |label: &'static [u8]| -> usize { + labels + .iter() + .position(|l| *l == label) + .expect("missing required control stage claim label in batched_time") + }; + let control_linear_idx = find_w4(b"control/next_pc_linear"); + let control_control_idx = find_w4(b"control/next_pc_control"); + let control_branch_idx = find_w4(b"control/branch_semantics"); + let _control_writeback_idx = find_w4(b"control/writeback"); + assert!( + control_linear_idx < labels.len() && control_control_idx < labels.len() && control_branch_idx < labels.len(), + "control stage labels must be present in batched_time" + ); + + let mut proof_missing_control_claim = proof.clone(); + let _ = proof_missing_control_claim.steps[0] + .batched_time + .claimed_sums + .remove(control_control_idx); + let _ = proof_missing_control_claim.steps[0] + .batched_time + .degree_bounds + .remove(control_control_idx); + let _ = proof_missing_control_claim.steps[0] + .batched_time + .labels + .remove(control_control_idx); + let _ = proof_missing_control_claim.steps[0] + .batched_time + .round_polys + .remove(control_control_idx); + assert!( + run.verify_proof(&proof_missing_control_claim).is_err(), + "missing control/next_pc_control claim artifact must fail verification" + ); + + let mut proof_tampered_control_round = proof.clone(); + let coeff = proof_tampered_control_round.steps[0] + .batched_time + .round_polys[control_control_idx] + .get_mut(0) + .and_then(|round| round.get_mut(0)) + .expect("control/next_pc_control first-round coeff must exist"); + *coeff += K::ONE; + assert!( + run.verify_proof(&proof_tampered_control_round).is_err(), + "tampered control/next_pc_control round polynomial must fail verification" + ); +} + +#[test] +fn rv32_trace_wiring_runner_rejects_zero_chunk_rows() { + let program = vec![RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let err = match 
Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(0) + .prove() + { + Ok(_) => panic!("chunk_rows=0 must be rejected"), + Err(e) => e, + }; + + let msg = err.to_string(); + assert!(msg.contains("chunk_rows"), "unexpected error message: {msg}"); +} + +#[test] +fn rv32_trace_wiring_runner_rejects_amo_via_wb_decode_scope_lock() { + // Program includes one AMO row. In Tier 2.1 trace mode this is rejected by WB/decode stage + // decode residuals (scope lock), not by the N0 main-trace CCS. + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 5, + }, + RiscvInstruction::Amo { + op: RiscvMemOp::AmoaddW, + rd: 2, + rs1: 0, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + assert!( + Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .is_err(), + "AMO must be rejected in Tier 2.1 trace mode via WB/decode stage scope lock" + ); +} + +fn prove_verify_trace_program(program: Vec) { + let program_bytes = encode_program(&program); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .min_trace_len(program.len()) + .max_steps(program.len()) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); +} + +#[test] +fn rv32_trace_wiring_runner_accepts_mixed_addi_andi_halt() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + prove_verify_trace_program(program); +} + +#[test] +fn rv32_trace_wiring_runner_accepts_mixed_addi_ori_halt() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + prove_verify_trace_program(program); +} + +#[test] +fn 
rv32_trace_wiring_runner_accepts_mixed_with_srai_halt() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 2, + rs1: 1, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + prove_verify_trace_program(program); +} + +#[test] +fn rv32_trace_wiring_runner_accepts_full_mixed_sequence_halt() { + let mut program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 4, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Slt, + rd: 6, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sltu, + rd: 7, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sll, + rd: 8, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Srl, + rd: 9, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 10, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Branch { + cond: BranchCondition::Ne, + rs1: 0, + rs2: 0, + imm: 8, + }, + ]; + program.push(RiscvInstruction::Halt); + prove_verify_trace_program(program.clone()); +} diff --git a/crates/neo-fold/tests/shard_continuation_extend_and_fold.rs b/crates/neo-fold/tests/suites/integration/shard_continuation_extend_and_fold.rs similarity index 100% rename from crates/neo-fold/tests/shard_continuation_extend_and_fold.rs rename to crates/neo-fold/tests/suites/integration/shard_continuation_extend_and_fold.rs diff --git a/crates/neo-fold/tests/streaming_dec_equivalence.rs b/crates/neo-fold/tests/suites/integration/streaming_dec_equivalence.rs similarity index 91% rename from crates/neo-fold/tests/streaming_dec_equivalence.rs rename to 
crates/neo-fold/tests/suites/integration/streaming_dec_equivalence.rs index ab025720..207296f0 100644 --- a/crates/neo-fold/tests/streaming_dec_equivalence.rs +++ b/crates/neo-fold/tests/suites/integration/streaming_dec_equivalence.rs @@ -25,27 +25,7 @@ fn create_identity_ccs(n: usize) -> CcsStructure { } fn mixers() -> CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt> { - fn mix_rhos_commits(_rhos: &[Mat], cs: &[Cmt]) -> Cmt { - assert_eq!(cs.len(), 1, "test mixers expect k=1"); - cs[0].clone() - } - fn combine_b_pows(cs: &[Cmt], b: u32) -> Cmt { - assert!(!cs.is_empty(), "combine_b_pows: empty commitments"); - let mut acc = cs[0].clone(); - let b_f = F::from_u64(b as u64); - let mut pow = b_f; - for i in 1..cs.len() { - for (a, &x) in acc.data.iter_mut().zip(cs[i].data.iter()) { - *a += x * pow; - } - pow *= b_f; - } - acc - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn build_single_step_bundle(params: &NeoParams, l: &AjtaiSModule, m: usize) -> StepWitnessBundle { diff --git a/crates/neo-fold/tests/memory_adversarial_tests.rs b/crates/neo-fold/tests/suites/perf/memory_adversarial_tests.rs similarity index 93% rename from crates/neo-fold/tests/memory_adversarial_tests.rs rename to crates/neo-fold/tests/suites/perf/memory_adversarial_tests.rs index 95232f05..b8a5fb22 100644 --- a/crates/neo-fold/tests/memory_adversarial_tests.rs +++ b/crates/neo-fold/tests/suites/perf/memory_adversarial_tests.rs @@ -53,16 +53,7 @@ fn setup_ajtai_pp(m: usize, seed: u64) -> AjtaiSModule { } fn default_mixers() -> Mixers { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn create_identity_ccs(n: usize) -> CcsStructure { @@ -85,6 +76,7 @@ fn create_mcs_from_z( } fn make_twist_instance( + mem_id: u32, layout: 
&PlainMemLayout, init: MemInit, steps: usize, @@ -95,6 +87,7 @@ fn make_twist_instance( let ell = layout.n_side.trailing_zeros() as usize; ( neo_memory::witness::MemInstance { + mem_id, comms: Vec::new(), k: layout.k, d: layout.d, @@ -246,7 +239,7 @@ fn memory_cross_step_read_consistency() { inc_at_write_addr: vec![F::from_u64(42)], // 42 - 0 = 42 }; let mem_init = MemInit::Zero; - let (mem_inst, mem_wit) = make_twist_instance(&mem_layout, mem_init, 1); + let (mem_inst, mem_wit) = make_twist_instance(0, &mem_layout, mem_init, 1); steps.push(create_step_with_twist_bus( ¶ms, &ccs, @@ -270,7 +263,7 @@ fn memory_cross_step_read_consistency() { }; // Memory state after step 0: addr[0] = 42 let mem_init = MemInit::Sparse(vec![(0, F::from_u64(42))]); - let (mem_inst, mem_wit) = make_twist_instance(&mem_layout, mem_init, 1); + let (mem_inst, mem_wit) = make_twist_instance(0, &mem_layout, mem_init, 1); steps.push(create_step_with_twist_bus( ¶ms, &ccs, @@ -293,7 +286,7 @@ fn memory_cross_step_read_consistency() { inc_at_write_addr: vec![F::from_u64(100) - F::from_u64(42)], // 100 - 42 = 58 }; let mem_init = MemInit::Sparse(vec![(0, F::from_u64(42))]); - let (mem_inst, mem_wit) = make_twist_instance(&mem_layout, mem_init, 1); + let (mem_inst, mem_wit) = make_twist_instance(0, &mem_layout, mem_init, 1); steps.push(create_step_with_twist_bus( ¶ms, &ccs, @@ -308,7 +301,7 @@ fn memory_cross_step_read_consistency() { let mut tr_prove = Poseidon2Transcript::new(b"mem-cross-step"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -323,7 +316,7 @@ fn memory_cross_step_read_consistency() { let mut tr_verify = Poseidon2Transcript::new(b"mem-cross-step"); let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -367,7 +360,7 @@ fn memory_read_uninitialized_returns_zero() { }; let mem_init = 
MemInit::Zero; - let (mem_inst, mem_wit) = make_twist_instance(&mem_layout, mem_init, 1); + let (mem_inst, mem_wit) = make_twist_instance(0, &mem_layout, mem_init, 1); let step_bundle = create_step_with_twist_bus(¶ms, &ccs, &l, 0, vec![(mem_inst, mem_wit, mem_trace)]); let acc_init: Vec> = Vec::new(); @@ -375,7 +368,7 @@ fn memory_read_uninitialized_returns_zero() { let mut tr_prove = Poseidon2Transcript::new(b"mem-uninitialized"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -390,7 +383,7 @@ fn memory_read_uninitialized_returns_zero() { let mut tr_verify = Poseidon2Transcript::new(b"mem-uninitialized"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -437,7 +430,7 @@ fn memory_tamper_read_value_fails() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = make_twist_instance(&mem_layout, mem_init, 1); + let (mem_inst, mem_wit) = make_twist_instance(0, &mem_layout, mem_init, 1); let step_bundle = create_step_with_twist_bus(¶ms, &ccs, &l, 0, vec![(mem_inst, mem_wit, bad_mem_trace)]); let acc_init: Vec> = Vec::new(); @@ -445,7 +438,7 @@ fn memory_tamper_read_value_fails() { let mut tr_prove = Poseidon2Transcript::new(b"mem-tamper-read"); let proof_result = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -462,7 +455,7 @@ fn memory_tamper_read_value_fails() { let mut tr_verify = Poseidon2Transcript::new(b"mem-tamper-read"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let verify_result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -508,7 +501,7 @@ fn memory_tamper_write_increment_fails() { write_val: vec![F::from_u64(100)], inc_at_write_addr: vec![F::from_u64(10)], // WRONG: should be 58 }; - let (mem_inst, mem_wit) = 
make_twist_instance(&mem_layout, mem_init, 1); + let (mem_inst, mem_wit) = make_twist_instance(0, &mem_layout, mem_init, 1); let step_bundle = create_step_with_twist_bus(¶ms, &ccs, &l, 0, vec![(mem_inst, mem_wit, bad_mem_trace)]); let acc_init: Vec> = Vec::new(); @@ -516,7 +509,7 @@ fn memory_tamper_write_increment_fails() { let mut tr_prove = Poseidon2Transcript::new(b"mem-tamper-inc"); let proof_result = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -533,7 +526,7 @@ fn memory_tamper_write_increment_fails() { let mut tr_verify = Poseidon2Transcript::new(b"mem-tamper-inc"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let verify_result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -596,8 +589,8 @@ fn memory_multiple_regions_same_step() { }; let reg_init = MemInit::Sparse(vec![(0, F::from_u64(10))]); - let (ram_inst, ram_wit) = make_twist_instance(&ram_layout, ram_init, 1); - let (reg_inst, reg_wit) = make_twist_instance(®_layout, reg_init, 1); + let (ram_inst, ram_wit) = make_twist_instance(0, &ram_layout, ram_init, 1); + let (reg_inst, reg_wit) = make_twist_instance(1, ®_layout, reg_init, 1); let step_bundle = create_step_with_twist_bus( ¶ms, &ccs, @@ -611,7 +604,7 @@ fn memory_multiple_regions_same_step() { let mut tr_prove = Poseidon2Transcript::new(b"mem-multi-region"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -626,7 +619,7 @@ fn memory_multiple_regions_same_step() { let mut tr_verify = Poseidon2Transcript::new(b"mem-multi-region"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -672,7 +665,7 @@ fn memory_sparse_initialization() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = make_twist_instance(&mem_layout, mem_init, 1); + 
let (mem_inst, mem_wit) = make_twist_instance(0, &mem_layout, mem_init, 1); let step_bundle = create_step_with_twist_bus(¶ms, &ccs, &l, 0, vec![(mem_inst, mem_wit, mem_trace)]); let acc_init: Vec> = Vec::new(); @@ -680,7 +673,7 @@ fn memory_sparse_initialization() { let mut tr_prove = Poseidon2Transcript::new(b"mem-sparse-init"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -695,7 +688,7 @@ fn memory_sparse_initialization() { let mut tr_verify = Poseidon2Transcript::new(b"mem-sparse-init"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, diff --git a/crates/neo-fold/tests/suites/perf/mod.rs b/crates/neo-fold/tests/suites/perf/mod.rs new file mode 100644 index 00000000..ec884b5a --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/mod.rs @@ -0,0 +1,5 @@ +mod memory_adversarial_tests; +mod prefix_scaling; +mod riscv_trace_ab_perf; +mod riscv_trace_wiring_output_binding_perf; +mod single_addi_metrics_nightstream; diff --git a/crates/neo-fold/tests/suites/perf/nightstream_prefix_scaling_perf.rs b/crates/neo-fold/tests/suites/perf/nightstream_prefix_scaling_perf.rs new file mode 100644 index 00000000..639614c7 --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/nightstream_prefix_scaling_perf.rs @@ -0,0 +1,200 @@ +#![allow(non_snake_case)] + +use std::time::{Duration, Instant}; + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{encode_program, BranchCondition, RiscvInstruction, RiscvOpcode}; + +struct ScaleRow { + n_instr: usize, + + ns_step_rows_raw: usize, + ns_step_rows_p2: usize, + ns_cols_p2: usize, + ns_fold_chunks: usize, + ns_rows_total_padded: usize, + ns_prove_time: Duration, + ns_verify_time: Duration, + ns_total_time: Duration, +} + +#[test] +#[ignore = "perf test; run with `cargo test -p neo-fold --test nightstream_prefix_scaling_perf 
--release -- --ignored --nocapture`"] +fn nightstream_prefix_lengths_1_to_10_and_256() { + // Fixed instruction sequence; we benchmark prefixes of length 1..10 and 256. + // + // Nightstream: execute `n` RV32 instructions and prove them as a single chunk (chunk_size=n), + // so this is one proof per prefix length (no folding per instruction). + let base_sequence = instruction_sequence(); + assert_eq!(base_sequence.len(), 10); + + let mut rows: Vec = Vec::with_capacity(11); + let mut ns: Vec = (1..=10).collect(); + ns.push(256); + + for n in ns { + let ns_program: Vec = (0..n) + .map(|i| base_sequence[i % base_sequence.len()].clone()) + .collect(); + let ns_program_bytes = encode_program(&ns_program); + + let ns_total_start = Instant::now(); + let mut ns_run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &ns_program_bytes) + // IMPORTANT: avoid "fold per instruction". + // Use a single folding chunk that covers the entire prefix so this is one proof for `n` instructions. + .min_trace_len(n) + .chunk_rows(n) + .max_steps(n) + .prove() + .expect("Nightstream prove"); + + let ns_step_rows_raw = ns_run.ccs_num_constraints(); + let ns_cols_raw = ns_run.ccs_num_variables(); + let ns_step_rows_p2 = ns_step_rows_raw.next_power_of_two(); + let ns_cols_p2 = ns_cols_raw.next_power_of_two(); + let ns_fold_chunks = ns_run.fold_count(); + let ns_rows_total_padded = ns_step_rows_p2.saturating_mul(ns_fold_chunks); + + ns_run.verify().expect("Nightstream verify"); + let ns_prove_time = ns_run.prove_duration(); + let ns_verify_time = ns_run + .verify_duration() + .expect("Nightstream verify duration"); + let ns_total_time = ns_total_start.elapsed(); + + rows.push(ScaleRow { + n_instr: n, + + ns_step_rows_raw, + ns_step_rows_p2, + ns_cols_p2, + ns_fold_chunks, + ns_rows_total_padded, + ns_prove_time, + ns_verify_time, + ns_total_time, + }); + } + + println!(); + println!("{:=<105}", ""); + println!("NIGHTSTREAM — PREFIX SCALING (n=1..10, 256)"); + println!("{:=<105}", ""); + 
println!("Note: times include per-run setup; compare trends (slope) more than absolute intercept on tiny traces."); + println!("Note: rowsTotal = next_pow2(ccs.n) * fold_chunks, cols(p2) = next_pow2(ccs.m)."); + println!(); + + println!("{:-<105}", ""); + println!( + "{:>4} {:>13} {:>10} {:>10} {:>8} {:>9} {:>9} {:>9} {:>9}", + "n", "rows/chunk", "rowsTot", "cols(p2)", "chunks", "prove", "verify", "total", "prove/n", + ); + println!("{:-<105}", ""); + for r in &rows { + let ns_rows_step = format!("{}/{}", r.ns_step_rows_raw, r.ns_step_rows_p2); + println!( + "{:>4} {:>13} {:>10} {:>10} {:>8} {:>9} {:>9} {:>9} {:>9}", + r.n_instr, + ns_rows_step, + r.ns_rows_total_padded, + r.ns_cols_p2, + r.ns_fold_chunks, + fmt_duration(r.ns_prove_time), + fmt_duration(r.ns_verify_time), + fmt_duration(r.ns_total_time), + fmt_duration(div_duration(r.ns_prove_time, r.n_instr)), + ); + } + println!("{:-<105}", ""); + println!(); +} + +fn instruction_sequence() -> Vec { + vec![ + // ADDI x1,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + // ANDI x2,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 2, + rs1: 0, + imm: 1, + }, + // ORI x3,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 0, + imm: 1, + }, + // XORI x4,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 4, + rs1: 0, + imm: 1, + }, + // SLTI x6,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Slt, + rd: 6, + rs1: 0, + imm: 1, + }, + // SLTIU x7,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Sltu, + rd: 7, + rs1: 0, + imm: 1, + }, + // SLLI x8,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Sll, + rd: 8, + rs1: 0, + imm: 1, + }, + // SRLI x9,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Srl, + rd: 9, + rs1: 0, + imm: 1, + }, + // SRAI x10,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 10, + rs1: 0, + imm: 1, + }, + // BNE x0,x0,+8 (not taken) + RiscvInstruction::Branch { + cond: BranchCondition::Ne, + rs1: 
0, + rs2: 0, + imm: 8, + }, + ] +} + +fn fmt_duration(d: Duration) -> String { + if d.as_secs_f64() < 1.0 { + format!("{:.3}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.3}s", d.as_secs_f64()) + } +} + +fn div_duration(d: Duration, denom: usize) -> Duration { + if denom == 0 { + return Duration::from_secs(0); + } + Duration::from_secs_f64(d.as_secs_f64() / denom as f64) +} diff --git a/crates/neo-fold/tests/suites/perf/prefix_scaling.rs b/crates/neo-fold/tests/suites/perf/prefix_scaling.rs new file mode 100644 index 00000000..757917f4 --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/prefix_scaling.rs @@ -0,0 +1,4 @@ +#[path = "nightstream_prefix_scaling_perf.rs"] +mod nightstream_prefix_scaling_perf; +#[path = "riscv_prefix_scaling_nightstream.rs"] +mod riscv_prefix_scaling_nightstream; diff --git a/crates/neo-fold/tests/suites/perf/riscv_prefix_scaling_nightstream.rs b/crates/neo-fold/tests/suites/perf/riscv_prefix_scaling_nightstream.rs new file mode 100644 index 00000000..db9ad1fc --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/riscv_prefix_scaling_nightstream.rs @@ -0,0 +1,191 @@ +use std::time::{Duration, Instant}; + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{encode_program, BranchCondition, RiscvInstruction, RiscvOpcode}; + +struct ScaleRow { + n_instr: usize, + + ns_step_rows_raw: usize, + ns_step_rows_p2: usize, + ns_cols_raw: usize, + ns_cols_p2: usize, + ns_fold_chunks: usize, + ns_rows_total_padded: usize, + ns_prove_time: Duration, + ns_verify_time: Duration, + ns_total_time: Duration, +} + +#[test] +#[ignore = "perf-style test: run with `cargo test -p neo-fold --test riscv_prefix_scaling_nightstream --release -- --ignored --nocapture`"] +fn nightstream_prefix_lengths_1_to_10_and_256_halt_terminated() { + // Fixed instruction sequence; we benchmark prefixes of length 1..10, plus 256. + // + // We always append a HALT so each program terminates. 
We then prove the whole trace as a + // single chunk by setting chunk_size = trace_len (no folding per instruction). + let base_sequence: Vec = instruction_sequence(); + assert_eq!(base_sequence.len(), 10); + + let mut rows: Vec = Vec::with_capacity(11); + let mut ns: Vec = (1..=10).collect(); + ns.push(256); + + for n in ns { + let mut program: Vec = (0..n) + .map(|i| base_sequence[i % base_sequence.len()].clone()) + .collect(); + program.push(RiscvInstruction::Halt); + let program_bytes = encode_program(&program); + + let trace_len = n + 1; + let ns_total_start = Instant::now(); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .min_trace_len(trace_len) + .chunk_rows(trace_len) + .max_steps(trace_len) + .prove() + .expect("Nightstream prove"); + + let ns_step_rows_raw = run.ccs_num_constraints(); + let ns_cols_raw = run.ccs_num_variables(); + let ns_step_rows_p2 = ns_step_rows_raw.next_power_of_two(); + let ns_cols_p2 = ns_cols_raw.next_power_of_two(); + let ns_fold_chunks = run.fold_count(); + let ns_rows_total_padded = ns_step_rows_p2.saturating_mul(ns_fold_chunks); + + run.verify().expect("Nightstream verify"); + let ns_prove_time = run.prove_duration(); + let ns_verify_time = run.verify_duration().unwrap_or(Duration::ZERO); + let ns_total_time = ns_total_start.elapsed(); + + rows.push(ScaleRow { + n_instr: n, + + ns_step_rows_raw, + ns_step_rows_p2, + ns_cols_raw, + ns_cols_p2, + ns_fold_chunks, + ns_rows_total_padded, + ns_prove_time, + ns_verify_time, + ns_total_time, + }); + } + + println!(); + println!("{:=<110}", ""); + println!("NIGHTSTREAM — SCALING (prefix n=1..10, 256; prove as single chunk)"); + println!("{:=<110}", ""); + println!("Note: times include per-run setup; compare trends more than absolute intercept on tiny traces."); + println!(); + + println!("{:-<110}", ""); + println!( + "{:>4} {:>14} {:>10} {:>10} {:>10} {:>9} {:>9} {:>9} {:>9}", + "n", "NS rows/chunk", "NS rowsTot", "NS cols", "NS cols(p2)", 
"chunks", "prove", "verify", "total", + ); + println!("{:-<110}", ""); + for r in &rows { + let ns_rows_step = format!("{}/{}", r.ns_step_rows_raw, r.ns_step_rows_p2); + println!( + "{:>4} {:>14} {:>10} {:>10} {:>10} {:>9} {:>9} {:>9} {:>9}", + r.n_instr, + ns_rows_step, + r.ns_rows_total_padded, + r.ns_cols_raw, + r.ns_cols_p2, + r.ns_fold_chunks, + fmt_duration(r.ns_prove_time), + fmt_duration(r.ns_verify_time), + fmt_duration(r.ns_total_time), + ); + } + println!("{:-<110}", ""); + println!(); +} + +fn instruction_sequence() -> Vec { + vec![ + // ADDI x1,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + // ANDI x2,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 2, + rs1: 0, + imm: 1, + }, + // ORI x3,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 0, + imm: 1, + }, + // XORI x4,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 4, + rs1: 0, + imm: 1, + }, + // SLTI x6,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Slt, + rd: 6, + rs1: 0, + imm: 1, + }, + // SLTIU x7,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Sltu, + rd: 7, + rs1: 0, + imm: 1, + }, + // SLLI x8,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Sll, + rd: 8, + rs1: 0, + imm: 1, + }, + // SRLI x9,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Srl, + rd: 9, + rs1: 0, + imm: 1, + }, + // SRAI x10,x0,1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 10, + rs1: 0, + imm: 1, + }, + // BNE x0,x0,+8 (not taken) + RiscvInstruction::Branch { + cond: BranchCondition::Ne, + rs1: 0, + rs2: 0, + imm: 8, + }, + ] +} + +fn fmt_duration(d: Duration) -> String { + if d.as_secs_f64() < 1.0 { + format!("{:.3}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.3}s", d.as_secs_f64()) + } +} diff --git a/crates/neo-fold/tests/suites/perf/riscv_trace_ab_perf.rs b/crates/neo-fold/tests/suites/perf/riscv_trace_ab_perf.rs new file mode 100644 index 00000000..8be525cb --- /dev/null +++ 
b/crates/neo-fold/tests/suites/perf/riscv_trace_ab_perf.rs @@ -0,0 +1,154 @@ +#![allow(non_snake_case)] + +use std::time::Duration; + +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; + +#[derive(Clone, Copy, Debug)] +struct Stats { + min: Duration, + median: Duration, + mean: Duration, + max: Duration, +} + +#[test] +#[ignore = "perf-style test: run with `cargo test -p neo-fold --release --test riscv_trace_ab_perf -- --ignored --nocapture`"] +fn rv32_trace_ab_perf_single_chunk() { + let repeats = env_usize("AB_REPEATS", 64); + let warmups = env_usize("AB_WARMUPS", 1); + let samples = env_usize("AB_SAMPLES", 7); + assert!(repeats > 0, "AB_REPEATS must be > 0"); + assert!(samples > 0, "AB_SAMPLES must be > 0"); + + let mut program = Vec::::new(); + for _ in 0..repeats { + program.extend([ + // x1 = 3; x2 = 4 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 3, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 4, + }, + // x3 = x1 * x2 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, + // mem[0] = x3; x4 = mem[0] + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 3, + imm: 0, + }, + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 4, + rs1: 0, + imm: 0, + }, + ]); + } + program.push(RiscvInstruction::Halt); + + let program_bytes = encode_program(&program); + let max_steps = program.len(); + + for _ in 0..warmups { + let mut run = run_once(&program_bytes, max_steps).expect("warmup prove"); + run.verify().expect("warmup verify"); + } + + let mut prove_times = Vec::with_capacity(samples); + let mut verify_times = Vec::with_capacity(samples); + for _ in 0..samples { + let mut run = run_once(&program_bytes, max_steps).expect("prove"); + run.verify().expect("verify"); + prove_times.push(run.prove_duration()); + 
verify_times.push(run.verify_duration().unwrap_or(Duration::ZERO)); + } + + let prove = summarize(&prove_times); + let verify = summarize(&verify_times); + + println!(); + println!("{:=<96}", ""); + println!("RV32 Trace A/B PERF (single chunk, fixed program)"); + println!("{:=<96}", ""); + println!( + "config: repeats={} instructions={} warmups={} samples={}", + repeats, max_steps, warmups, samples + ); + println!("{:-<96}", ""); + println!( + "{:>10} {:>10} {:>10} {:>10} {:>10}", + "phase", "min", "median", "mean", "max" + ); + println!("{:-<96}", ""); + println!( + "{:>10} {:>10} {:>10} {:>10} {:>10}", + "prove", + fmt_duration(prove.min), + fmt_duration(prove.median), + fmt_duration(prove.mean), + fmt_duration(prove.max), + ); + println!( + "{:>10} {:>10} {:>10} {:>10} {:>10}", + "verify", + fmt_duration(verify.min), + fmt_duration(verify.median), + fmt_duration(verify.mean), + fmt_duration(verify.max), + ); + println!("{:-<96}", ""); + println!(); +} + +fn run_once(program_bytes: &[u8], max_steps: usize) -> Result { + Rv32TraceWiring::from_rom(/*program_base=*/ 0, program_bytes) + .xlen(32) + .min_trace_len(max_steps) + .chunk_rows(max_steps) + .max_steps(max_steps) + .shout_auto_minimal() + .prove() +} + +fn summarize(samples: &[Duration]) -> Stats { + assert!(!samples.is_empty()); + let mut v = samples.to_vec(); + v.sort_unstable(); + let min = v[0]; + let max = v[v.len() - 1]; + let median = v[v.len() / 2]; + let mean_secs = v.iter().map(Duration::as_secs_f64).sum::() / v.len() as f64; + let mean = Duration::from_secs_f64(mean_secs); + Stats { min, median, mean, max } +} + +fn fmt_duration(d: Duration) -> String { + if d.as_secs_f64() < 1.0 { + format!("{:.3}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.3}s", d.as_secs_f64()) + } +} + +fn env_usize(name: &str, default: usize) -> usize { + match std::env::var(name) { + Ok(v) => v.parse::().unwrap_or(default), + Err(_) => default, + } +} diff --git 
a/crates/neo-fold/tests/suites/perf/riscv_trace_wiring_output_binding_perf.rs b/crates/neo-fold/tests/suites/perf/riscv_trace_wiring_output_binding_perf.rs new file mode 100644 index 00000000..870967cb --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/riscv_trace_wiring_output_binding_perf.rs @@ -0,0 +1,195 @@ +#![allow(non_snake_case)] + +use std::time::Duration; + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +#[derive(Clone, Copy, Debug)] +enum ClaimMode { + None, + Reg, +} + +#[derive(Clone, Copy, Debug)] +struct Stats { + min: Duration, + median: Duration, + mean: Duration, + max: Duration, +} + +#[test] +#[ignore = "perf-style test: run with `cargo test -p neo-fold --release --test riscv_trace_wiring_output_binding_perf -- --ignored --nocapture`"] +fn rv32_trace_wiring_output_binding_overhead_perf() { + let n_adds = env_usize("TW_N_ADDS", 512); + let samples = env_usize("TW_SAMPLES", 7); + let warmups = env_usize("TW_WARMUPS", 1); + assert!(n_adds > 0, "TW_N_ADDS must be > 0"); + assert!(samples > 0, "TW_SAMPLES must be > 0"); + let expected_x1 = F::from_u64(n_adds as u64); + + let mut program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 1, + imm: 1, + }; + n_adds + ]; + program.push(RiscvInstruction::Halt); + let program_bytes = encode_program(&program); + + let (none, none_shape) = run_samples( + &program_bytes, + n_adds + 1, + expected_x1, + ClaimMode::None, + warmups, + samples, + ); + let (reg, _reg_shape) = run_samples( + &program_bytes, + n_adds + 1, + expected_x1, + ClaimMode::Reg, + warmups, + samples, + ); + + let none_stats = summarize(&none); + let reg_stats = summarize(®); + let median_ratio = ratio(reg_stats.median, none_stats.median); + let mean_ratio = ratio(reg_stats.mean, none_stats.mean); + + println!(); + println!("{:=<96}", ""); + println!("RV32 TRACE WIRING 
PERF — NO OUTPUT vs REG OUTPUT BINDING"); + println!("{:=<96}", ""); + println!( + "config: n_adds={} trace_len={} warmups={} samples={}", + n_adds, + n_adds + 1, + warmups, + samples + ); + if let Some((ccs_n, ccs_m, trace_len)) = none_shape { + println!("shape: ccs_n={} ccs_m={} trace_len={}", ccs_n, ccs_m, trace_len); + } + println!("{:-<96}", ""); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "mode", "min", "median", "mean", "max" + ); + println!("{:-<96}", ""); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "no-output", + fmt_duration(none_stats.min), + fmt_duration(none_stats.median), + fmt_duration(none_stats.mean), + fmt_duration(none_stats.max), + ); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "reg-output", + fmt_duration(reg_stats.min), + fmt_duration(reg_stats.median), + fmt_duration(reg_stats.mean), + fmt_duration(reg_stats.max), + ); + println!("{:-<96}", ""); + println!( + "ratio reg/no-output: median={:.3}x mean={:.3}x", + median_ratio, mean_ratio + ); + let trace_len = n_adds + 1; + let none_khz = trace_len as f64 / none_stats.median.as_secs_f64() / 1_000.0; + let reg_khz = trace_len as f64 / reg_stats.median.as_secs_f64() / 1_000.0; + println!( + "throughput (median): no-output={:.3} kHz reg-output={:.3} kHz", + none_khz, reg_khz + ); + println!("{:-<96}", ""); + println!(); +} + +fn run_samples( + program_bytes: &[u8], + max_steps: usize, + expected_x1: F, + claim_mode: ClaimMode, + warmups: usize, + samples: usize, +) -> (Vec, Option<(usize, usize, usize)>) { + for _ in 0..warmups { + let mut run = build_runner(program_bytes, max_steps, expected_x1, claim_mode) + .prove() + .expect("warmup prove"); + run.verify().expect("warmup verify"); + } + + let mut out = Vec::with_capacity(samples); + let mut shape: Option<(usize, usize, usize)> = None; + for _ in 0..samples { + let mut run = build_runner(program_bytes, max_steps, expected_x1, claim_mode) + .prove() + .expect("prove"); + run.verify().expect("verify"); + if 
shape.is_none() { + shape = Some((run.ccs_num_constraints(), run.ccs_num_variables(), run.trace_len())); + } + out.push(run.prove_duration()); + } + (out, shape) +} + +fn build_runner(program_bytes: &[u8], max_steps: usize, expected_x1: F, claim_mode: ClaimMode) -> Rv32TraceWiring { + let runner = Rv32TraceWiring::from_rom(/*program_base=*/ 0, program_bytes) + .min_trace_len(max_steps) + .max_steps(max_steps); + + match claim_mode { + ClaimMode::None => runner, + ClaimMode::Reg => runner.reg_output_claim(/*reg=*/ 1, expected_x1), + } +} + +fn summarize(samples: &[Duration]) -> Stats { + assert!(!samples.is_empty()); + let mut v = samples.to_vec(); + v.sort_unstable(); + + let min = v[0]; + let max = v[v.len() - 1]; + let median = v[v.len() / 2]; + let mean_secs = v.iter().map(Duration::as_secs_f64).sum::() / v.len() as f64; + let mean = Duration::from_secs_f64(mean_secs); + + Stats { min, median, mean, max } +} + +fn ratio(numer: Duration, denom: Duration) -> f64 { + if denom.is_zero() { + return f64::INFINITY; + } + numer.as_secs_f64() / denom.as_secs_f64() +} + +fn fmt_duration(d: Duration) -> String { + if d.as_secs_f64() < 1.0 { + format!("{:.3}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.3}s", d.as_secs_f64()) + } +} + +fn env_usize(name: &str, default: usize) -> usize { + match std::env::var(name) { + Ok(v) => v.parse::().unwrap_or(default), + Err(_) => default, + } +} diff --git a/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs b/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs new file mode 100644 index 00000000..4be0923e --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs @@ -0,0 +1,910 @@ +use std::time::{Duration, Instant}; + +use neo_ccs::MeInstance; +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_fold::shard::ShardProof; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, Rv32TraceCcsLayout}; +use neo_memory::riscv::lookups::{encode_program, 
BranchCondition, RiscvInstruction, RiscvOpcode}; + +#[test] +#[ignore = "perf-style test: run with `cargo test -p neo-fold --release --test perf -- --ignored --nocapture compare_single_mixed_metrics_nightstream_only`"] +fn compare_single_mixed_metrics_nightstream_only() { + let instruction_label = "Mixed sequence (ADD/AND/OR/XOR/SLT/SLTU/SLL/SRL/SRA/BNE)"; + + let ns_program = mixed_instruction_sequence(); + let ns_program_bytes = encode_program(&ns_program); + let ns_chunk_rows = ns_program.len(); + let ns_max_steps = ns_program.len(); + + let ns_total_start = Instant::now(); + let mut ns_run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &ns_program_bytes) + .chunk_rows(ns_chunk_rows) + .max_steps(ns_max_steps) + .prove() + .expect("Nightstream prove"); + + let ns_constraints = ns_run.ccs_num_constraints(); + let ns_witness_cols = ns_run.ccs_num_variables(); + let ns_constraints_padded_pow2 = ns_constraints.next_power_of_two(); + let ns_witness_cols_padded_pow2 = ns_witness_cols.next_power_of_two(); + let ns_fold_count = ns_run.fold_count(); + let ns_trace_len = ns_run.trace_len(); + let ns_shout_tables = ns_run.used_shout_table_ids().len(); + let ns_step0 = ns_run + .steps_public() + .first() + .cloned() + .expect("Nightstream collected steps"); + let ns_m_in = ns_step0.mcs_inst.m_in; + let ns_witness_private = ns_witness_cols.saturating_sub(ns_m_in); + let ns_lut_instances = ns_step0.lut_insts.len(); + let ns_mem_instances = ns_step0.mem_insts.len(); + + ns_run.verify().expect("Nightstream verify"); + let ns_prove_time = ns_run.prove_duration(); + let ns_verify_time = ns_run + .verify_duration() + .expect("Nightstream verify duration"); + let ns_total_duration = ns_total_start.elapsed(); + + println!(); + println!("Instruction under test: {instruction_label}"); + println!(); + println!("**Nightstream (RV32 Trace)**"); + println!( + "- CCS: n={} constraints (padded_pow2_n={}), m={} cols (padded_pow2_m={}) (m_in={} public, w={} private)", + ns_constraints, + 
ns_constraints_padded_pow2, + ns_witness_cols, + ns_witness_cols_padded_pow2, + ns_m_in, + ns_witness_private + ); + println!( + "- Trace: executed_steps={} (max_steps={}), fold_chunks={} (chunk_rows={})", + ns_trace_len, ns_max_steps, ns_fold_count, ns_chunk_rows + ); + println!( + "- Sidecars: lut_instances={} mem_instances={} shout_tables_used={}", + ns_lut_instances, ns_mem_instances, ns_shout_tables + ); + println!( + "- Time: prove={} verify={} total_end_to_end={}", + fmt_duration(ns_prove_time), + fmt_duration(ns_verify_time), + fmt_duration(ns_total_duration) + ); + println!(); + + println!("{:-<80}", ""); + println!("{:<40} {:>18}", "Metric", "Nightstream"); + println!("{:<40} {:>18}", "", "(RV32 Trace)"); + println!("{:-<80}", ""); + println!("{:<40} {:>18}", "Rows per step (raw)", ns_constraints); + println!( + "{:<40} {:>18}", + "Rows per step (padded pow2)", ns_constraints_padded_pow2 + ); + println!( + "{:<40} {:>18}", + "Total rows in proof (padded)", + ns_constraints_padded_pow2.saturating_mul(ns_fold_count) + ); + println!( + "{:<40} {:>18}", + "Total rows (estimate, unpadded)", + ns_constraints.saturating_mul(ns_trace_len) + ); + println!("{:<40} {:>18}", "Cols / vars (raw)", ns_witness_cols); + println!( + "{:<40} {:>18}", + "Cols / vars (padded pow2)", ns_witness_cols_padded_pow2 + ); + println!("{:<40} {:>18}", "Public inputs (m_in)", ns_m_in); + println!( + "{:<40} {:>18}", + "Trace len (unpadded)", + format!("{} steps", ns_trace_len) + ); + println!("{:<40} {:>18}", "Lookup tables", format!("{} Shout", ns_lut_instances)); + println!("{:<40} {:>18}", "Shout tables used", ns_shout_tables); + println!("{:<40} {:>18}", "Prove time", fmt_duration(ns_prove_time)); + println!("{:<40} {:>18}", "Verify time", fmt_duration(ns_verify_time)); + println!("{:-<80}", ""); +} + +fn fmt_duration(d: Duration) -> String { + if d.as_secs_f64() < 1.0 { + format!("{:.3}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.3}s", d.as_secs_f64()) + } +} + 
+#[derive(Clone, Copy, Debug, Default)] +struct OpeningSurfaceBuckets { + core_ccs: usize, + sidecars: usize, + claim_reduction_linkage: usize, + pcs_open: usize, +} + +impl OpeningSurfaceBuckets { + fn total(self) -> usize { + self.core_ccs + self.sidecars + self.claim_reduction_linkage + self.pcs_open + } +} + +fn sum_y_scalars(claims: &[MeInstance]) -> usize { + claims.iter().map(|me| me.y_scalars.len()).sum() +} + +fn opening_surface_from_shard_proof(proof: &ShardProof) -> OpeningSurfaceBuckets { + let mut buckets = OpeningSurfaceBuckets::default(); + for step in &proof.steps { + buckets.core_ccs += sum_y_scalars(&step.fold.ccs_out); + + buckets.sidecars += sum_y_scalars(&step.mem.val_me_claims); + + buckets.claim_reduction_linkage += sum_y_scalars(&step.mem.wb_me_claims); + buckets.claim_reduction_linkage += sum_y_scalars(&step.mem.wp_me_claims); + buckets.claim_reduction_linkage += step.batched_time.claimed_sums.len(); + + buckets.pcs_open += step.fold.dec_children.len(); + buckets.pcs_open += step + .val_fold + .iter() + .map(|p| p.dec_children.len()) + .sum::(); + buckets.pcs_open += step + .wb_fold + .iter() + .map(|p| p.dec_children.len()) + .sum::(); + buckets.pcs_open += step + .wp_fold + .iter() + .map(|p| p.dec_children.len()) + .sum::(); + } + buckets +} + +fn env_usize(name: &str, default: usize) -> usize { + match std::env::var(name) { + Ok(v) => v.parse::().unwrap_or(default), + Err(_) => default, + } +} + +fn mixed_instruction_sequence() -> Vec { + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 4, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Slt, + rd: 6, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sltu, + rd: 7, + rs1: 0, + imm: 
1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sll, + rd: 8, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Srl, + rd: 9, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 10, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Branch { + cond: BranchCondition::Ne, + rs1: 0, + rs2: 0, + imm: 8, + }, + ] +} + +#[test] +#[ignore = "perf-style test: NS_DEBUG_N=256 cargo test -p neo-fold --release --test perf -- --ignored --nocapture debug_trace_single_n_mixed_ops"] +fn debug_trace_single_n_mixed_ops() { + let n = env_usize("NS_DEBUG_N", 256); + let chunk_rows = env_usize("NS_TRACE_CHUNK_ROWS", n + 1); + assert!(n > 0); + assert!(chunk_rows > 0); + + let base = mixed_instruction_sequence(); + let mut program: Vec = (0..n).map(|i| base[i % base.len()].clone()).collect(); + program.push(RiscvInstruction::Halt); + let program_bytes = encode_program(&program); + let steps = n + 1; + + let total_start = Instant::now(); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .min_trace_len(steps) + .max_steps(steps) + .chunk_rows(chunk_rows) + .prove() + .expect("trace prove"); + let prove_time = run.prove_duration(); + run.verify().expect("trace verify"); + let verify_time = run.verify_duration().expect("trace verify duration"); + let total_time = total_start.elapsed(); + let phases = run.prove_phase_durations(); + + println!( + "TRACE n={} chunk_rows={} ccs_n={} ccs_m={} n_p2={} m_p2={} trace_len={} folds={} prove={} verify={} total={} phases(setup={}, chunk_commit={}, fold={})", + n, + chunk_rows, + run.ccs_num_constraints(), + run.ccs_num_variables(), + run.ccs_num_constraints().next_power_of_two(), + run.ccs_num_variables().next_power_of_two(), + run.trace_len(), + run.fold_count(), + fmt_duration(prove_time), + fmt_duration(verify_time), + fmt_duration(total_time), + fmt_duration(phases.setup), + fmt_duration(phases.chunk_build_commit), + fmt_duration(phases.fold_and_prove), + ); + let openings = 
opening_surface_from_shard_proof(run.proof()); + println!( + "TRACE_OPENINGS core_ccs={} sidecars={} claim_reduction_linkage={} pcs_open={} total={}", + openings.core_ccs, + openings.sidecars, + openings.claim_reduction_linkage, + openings.pcs_open, + openings.total() + ); +} + +#[test] +#[ignore = "perf-style test: NS_DEBUG_N=256 cargo test -p neo-fold --release --test perf -- --ignored --nocapture debug_chunked_single_n_mixed_ops"] +fn debug_chunked_single_n_mixed_ops() { + let n = env_usize("NS_DEBUG_N", 256); + assert!(n > 0); + + let base = mixed_instruction_sequence(); + let mut program: Vec = (0..n).map(|i| base[i % base.len()].clone()).collect(); + program.push(RiscvInstruction::Halt); + let program_bytes = encode_program(&program); + let steps = n + 1; + + let total_start = Instant::now(); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .min_trace_len(steps) + .chunk_rows(steps) + .max_steps(steps) + .prove() + .expect("trace single-chunk prove"); + let prove_time = run.prove_duration(); + run.verify().expect("trace single-chunk verify"); + let verify_time = run + .verify_duration() + .expect("trace single-chunk verify duration"); + let total_time = total_start.elapsed(); + let trace_len = run.trace_len(); + let phases = run.prove_phase_durations(); + + println!( + "TRACE_SINGLE_CHUNK n={} ccs_n={} ccs_m={} n_p2={} m_p2={} trace_len={} folds={} prove={} verify={} total={} phases(setup={}, chunk_commit={}, fold={})", + n, + run.ccs_num_constraints(), + run.ccs_num_variables(), + run.ccs_num_constraints().next_power_of_two(), + run.ccs_num_variables().next_power_of_two(), + trace_len, + run.fold_count(), + fmt_duration(prove_time), + fmt_duration(verify_time), + fmt_duration(total_time), + fmt_duration(phases.setup), + fmt_duration(phases.chunk_build_commit), + fmt_duration(phases.fold_and_prove), + ); + let openings = opening_surface_from_shard_proof(run.proof()); + println!( + "TRACE_SINGLE_CHUNK_OPENINGS core_ccs={} sidecars={} 
claim_reduction_linkage={} pcs_open={} total={}", + openings.core_ccs, + openings.sidecars, + openings.claim_reduction_linkage, + openings.pcs_open, + openings.total() + ); +} + +#[test] +#[ignore = "perf-style report hook: cargo test -p neo-fold --release --test perf -- --ignored --nocapture debug_trace_core_rows_per_cycle_equiv"] +fn debug_trace_core_rows_per_cycle_equiv() { + let t = env_usize("NS_DEBUG_T", 257); + let layout = Rv32TraceCcsLayout::new(t).expect("trace layout"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace core ccs"); + println!( + "TRACE_CORE t={} trace_width={} core_ccs_n={} rows_per_cycle={:.3}", + t, + layout.trace.cols, + ccs.n, + ccs.n as f64 / t as f64 + ); +} + +#[test] +#[ignore = "W0 snapshot: NS_DEBUG_N=10 cargo test -p neo-fold --release --test perf -- --ignored --nocapture report_track_a_w0_w1_snapshot"] +fn report_track_a_w0_w1_snapshot() { + let n = env_usize("NS_DEBUG_N", 256); + assert!(n > 0); + let chunk_rows = n + 1; + + let base = mixed_instruction_sequence(); + let mut program: Vec = (0..n).map(|i| base[i % base.len()].clone()).collect(); + program.push(RiscvInstruction::Halt); + let program_bytes = encode_program(&program); + let steps = n + 1; + + let total_start = Instant::now(); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .min_trace_len(steps) + .max_steps(steps) + .chunk_rows(chunk_rows) + .prove() + .expect("trace prove"); + let prove_time = run.prove_duration(); + run.verify().expect("trace verify"); + let verify_time = run.verify_duration().expect("trace verify duration"); + let total_time = total_start.elapsed(); + let openings = opening_surface_from_shard_proof(run.proof()); + + let layout = Rv32TraceCcsLayout::new(steps).expect("trace layout"); + let core_ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace core ccs"); + let rows_per_cycle = core_ccs.n as f64 / steps as f64; + + let sep = "=".repeat(80); + let thin_sep = "-".repeat(80); + + println!("\n{sep}"); + println!(" 
TRACK A CONSTRAINT ARCHITECTURE REPORT (n={steps} steps)"); + println!("{sep}\n"); + + // ── 1. Main CCS Layer ── + println!("1. MAIN CCS LAYER (core glue constraints)"); + println!("{thin_sep}"); + println!(" Trace columns: {}", layout.trace.cols); + println!(" Core CCS rows (n): {}", core_ccs.n); + println!(" Core CCS cols (m): {}", core_ccs.m); + println!(" Rows per cycle: {:.3}", rows_per_cycle); + println!(" Public inputs (m_in): {}", layout.m_in); + println!(); + + let col_names = [ + "one", + "active", + "halted", + "cycle", + "pc_before", + "pc_after", + "instr_word", + "rs1_addr", + "rs1_val", + "rs2_addr", + "rs2_val", + "rd_addr", + "rd_val", + "ram_addr", + "ram_rv", + "ram_wv", + "shout_has_lookup", + "shout_val", + "shout_lhs", + "shout_rhs", + "jalr_drop_bit", + ]; + println!(" Trace columns ({}):", col_names.len()); + for (i, name) in col_names.iter().enumerate() { + println!(" [{i:>2}] {name}"); + } + println!(); + + // ── 2. Shared CPU Bus (Sidecar) Layer ── + println!("2. 
SHARED CPU BUS LAYER (Shout + Twist bus-tail columns)"); + println!("{thin_sep}"); + let total_ccs_m = run.ccs_num_variables(); + let total_ccs_n = run.ccs_num_constraints(); + let trace_base_m = layout.m_in + layout.trace.cols * steps; + let bus_tail_cols = total_ccs_m.saturating_sub(trace_base_m); + println!(" Total CCS m (with bus): {total_ccs_m}"); + println!(" Total CCS n (with bus): {total_ccs_n}"); + println!( + " Trace base m: {trace_base_m} (m_in={} + {}*{})", + layout.m_in, layout.trace.cols, steps + ); + println!(" Bus-tail columns: {bus_tail_cols}"); + let bus_reserved_rows = total_ccs_n.saturating_sub(core_ccs.n); + println!( + " Bus reserved rows: {bus_reserved_rows} (total_n={total_ccs_n} - core_n={})", + core_ccs.n + ); + println!(); + + let step0 = run + .steps_public() + .into_iter() + .next() + .expect("at least one step"); + let n_lut = step0.lut_insts.len(); + let n_mem = step0.mem_insts.len(); + println!(" Shout instances (LUT): {n_lut}"); + for inst in &step0.lut_insts { + let ell_addr = inst.d * inst.ell; + let bus_cols_per_lane = ell_addr + 2; + println!( + " - table_id={:<10} d={} n_side={} ell={} lanes={} bus_cols={}", + inst.table_id, + inst.d, + inst.n_side, + inst.ell, + inst.lanes, + bus_cols_per_lane * inst.lanes + ); + } + println!(" Twist instances (MEM): {n_mem}"); + for inst in &step0.mem_insts { + let ell_addr = inst.d * inst.ell; + let bus_cols_per_lane = 2 * ell_addr + 5; + println!( + " - mem_id={:<10} d={} n_side={} ell={} lanes={} bus_cols={}", + inst.mem_id, + inst.d, + inst.n_side, + inst.ell, + inst.lanes, + bus_cols_per_lane * inst.lanes + ); + } + println!(); + + // ── 3. Route-A Claims ── + println!("3. ROUTE-A BATCHED TIME CLAIMS"); + println!("{thin_sep}"); + let proof = run.proof(); + let step_proof = &proof.steps[0]; + let bt = &step_proof.batched_time; + println!(" Total batched claims: {}", bt.claimed_sums.len()); + println!(); + + // Group claims by category. 
+ let mut ccs_claims = Vec::new(); + let mut shout_claims = Vec::new(); + let mut twist_claims = Vec::new(); + let mut wb_wp_claims = Vec::new(); + let mut decode_claims = Vec::new(); + let mut width_claims = Vec::new(); + let mut control_claims = Vec::new(); + let mut other_claims = Vec::new(); + + for i in 0..bt.labels.len() { + let label = std::str::from_utf8(bt.labels[i]).unwrap_or(""); + let deg = bt.degree_bounds[i]; + let entry = (label.to_string(), deg); + if label.starts_with("ccs/") { + ccs_claims.push(entry); + } else if label.starts_with("shout/") { + shout_claims.push(entry); + } else if label.starts_with("twist/") { + twist_claims.push(entry); + } else if label.starts_with("wb/") || label.starts_with("wp/") { + wb_wp_claims.push(entry); + } else if label.starts_with("decode/") { + decode_claims.push(entry); + } else if label.starts_with("width/") { + width_claims.push(entry); + } else if label.starts_with("control/") { + control_claims.push(entry); + } else { + other_claims.push(entry); + } + } + + let print_group = |name: &str, claims: &[(String, usize)], aggregate: bool| { + if claims.is_empty() { + return; + } + println!(" {name} ({} claims):", claims.len()); + if aggregate { + // Aggregate by label, show count and degree range. 
+ let mut label_counts: Vec<(String, usize, usize, usize)> = Vec::new(); + for (label, deg) in claims { + if let Some(entry) = label_counts.iter_mut().find(|(l, _, _, _)| l == label) { + entry.1 += 1; + entry.2 = entry.2.min(*deg); + entry.3 = entry.3.max(*deg); + } else { + label_counts.push((label.clone(), 1, *deg, *deg)); + } + } + for (label, count, deg_min, deg_max) in &label_counts { + if deg_min == deg_max { + println!(" - {label:<40} x{count:<4} degree_bound={deg_min}"); + } else { + println!(" - {label:<40} x{count:<4} degree_bound={deg_min}..{deg_max}"); + } + } + } else { + for (label, deg) in claims { + println!(" - {label:<40} degree_bound={deg}"); + } + } + }; + + print_group("CCS (main constraint satisfaction)", &ccs_claims, false); + print_group("Shout (lookup argument)", &shout_claims, true); + print_group("Twist (memory argument)", &twist_claims, true); + print_group("WB/WP (booleanity + quiescence)", &wb_wp_claims, false); + print_group("Decode stage (lookup-backed decode)", &decode_claims, false); + print_group("Width stage (lookup-backed width)", &width_claims, false); + print_group("Control stage (branch/jump/writeback)", &control_claims, false); + print_group("Other", &other_claims, false); + println!(); + + // ── 4. Opening Surface ── + println!("4. OPENING SURFACE"); + println!("{thin_sep}"); + println!(" Core CCS: {}", openings.core_ccs); + println!(" Sidecars: {}", openings.sidecars); + println!(" Claim reduction/linkage: {}", openings.claim_reduction_linkage); + println!(" PCS open: {}", openings.pcs_open); + println!(" Total: {}", openings.total()); + println!(); + + // ── 5. Fold Lanes ── + println!("5. 
FOLD LANES"); + println!("{thin_sep}"); + println!(" Main fold (ccs_out): {} ME claims", step_proof.fold.ccs_out.len()); + println!( + " Main fold (dec children):{} DEC children", + step_proof.fold.dec_children.len() + ); + let val_count: usize = step_proof + .val_fold + .iter() + .map(|v| v.dec_children.len()) + .sum(); + println!( + " Val fold lanes: {} (dec children={})", + step_proof.val_fold.len(), + val_count + ); + let wb_count: usize = step_proof + .wb_fold + .iter() + .map(|w| w.dec_children.len()) + .sum(); + println!( + " WB fold lanes: {} (dec children={})", + step_proof.wb_fold.len(), + wb_count + ); + let wp_count: usize = step_proof + .wp_fold + .iter() + .map(|w| w.dec_children.len()) + .sum(); + println!( + " WP fold lanes: {} (dec children={})", + step_proof.wp_fold.len(), + wp_count + ); + println!(); + + // ── 6. ME Claims (Sidecar Proofs) ── + println!("6. MEMORY SIDECAR ME CLAIMS"); + println!("{thin_sep}"); + let mem = &step_proof.mem; + println!(" Val ME @ r_val: {} claims", mem.val_me_claims.len()); + println!(" WB ME claims: {} claims", mem.wb_me_claims.len()); + println!(" WP ME claims: {} claims", mem.wp_me_claims.len()); + println!(); + + // ── 7. Used Sets ── + println!("7. USED SETS (dynamic instantiation)"); + println!("{thin_sep}"); + println!(" Memory IDs (S_memory): {:?}", run.used_memory_ids()); + println!(" Shout table IDs (S_lookup): {:?}", run.used_shout_table_ids()); + println!(); + + // ── 8. Timing ── + println!("8. TIMING"); + println!("{thin_sep}"); + println!(" Prove: {}", fmt_duration(prove_time)); + println!(" Verify: {}", fmt_duration(verify_time)); + println!(" Total end-to-end: {}", fmt_duration(total_time)); + let phases = run.prove_phase_durations(); + println!(" Phase: setup {}", fmt_duration(phases.setup)); + println!(" Phase: chunk commit {}", fmt_duration(phases.chunk_build_commit)); + println!(" Phase: fold+prove {}", fmt_duration(phases.fold_and_prove)); + println!(); + + // ── 9. Summary ── + println!("9. 
SUMMARY"); + println!("{sep}"); + println!(" {:<36} {:>10}", "Main trace columns", layout.trace.cols); + println!(" {:<36} {:>10}", "Bus-tail columns", bus_tail_cols); + println!(" {:<36} {:>10}", "Core CCS rows", core_ccs.n); + println!(" {:<36} {:>10}", "Bus reserved rows", bus_reserved_rows); + println!(" {:<36} {:>10}", "Total CCS rows (n)", total_ccs_n); + println!(" {:<36} {:>10}", "Total CCS cols (m)", total_ccs_m); + println!(" {:<36} {:>10}", "Route-A batched claims", bt.claimed_sums.len()); + println!(" {:<36} {:>10}", " of which: CCS", ccs_claims.len()); + println!(" {:<36} {:>10}", " of which: Shout", shout_claims.len()); + println!(" {:<36} {:>10}", " of which: Twist", twist_claims.len()); + println!(" {:<36} {:>10}", " of which: WB/WP", wb_wp_claims.len()); + println!(" {:<36} {:>10}", " of which: Decode", decode_claims.len()); + println!(" {:<36} {:>10}", " of which: Width", width_claims.len()); + println!(" {:<36} {:>10}", " of which: Control", control_claims.len()); + println!(" {:<36} {:>10}", "Commit lanes", 1); + println!(" {:<36} {:>10}", "Committed sidecars", 0); + println!("{sep}"); +} + +#[test] +#[ignore = "perf-style test: NS_DEBUG_N=256 cargo test -p neo-fold --release --test perf -- --ignored --nocapture debug_trace_vs_chunked_single_n_mixed_ops"] +fn debug_trace_vs_chunked_single_n_mixed_ops() { + let n = env_usize("NS_DEBUG_N", 256); + let chunk_rows = env_usize("NS_TRACE_CHUNK_ROWS", n + 1); + assert!(n > 0); + assert!(chunk_rows > 0); + let base = mixed_instruction_sequence(); + assert_eq!(base.len(), 10); + + let mut program: Vec = (0..n).map(|i| base[i % base.len()].clone()).collect(); + program.push(RiscvInstruction::Halt); + let program_bytes = encode_program(&program); + let steps = n + 1; + + let chunk_total_start = Instant::now(); + let mut chunk_run = Rv32TraceWiring::from_rom(0, &program_bytes) + .min_trace_len(steps) + .chunk_rows(steps) + .max_steps(steps) + .prove() + .expect("trace single-chunk prove (mixed)"); + let 
chunk_prove = chunk_run.prove_duration(); + let chunk_phases = chunk_run.prove_phase_durations(); + chunk_run + .verify() + .expect("trace single-chunk verify (mixed)"); + let chunk_verify = chunk_run + .verify_duration() + .expect("trace single-chunk verify duration"); + let chunk_total = chunk_total_start.elapsed(); + + let trace_total_start = Instant::now(); + let trace_res = Rv32TraceWiring::from_rom(0, &program_bytes) + .min_trace_len(steps) + .max_steps(steps) + .chunk_rows(chunk_rows) + .prove(); + match trace_res { + Ok(mut trace_run) => { + let trace_prove = trace_run.prove_duration(); + trace_run.verify().expect("trace verify (mixed)"); + let trace_verify = trace_run.verify_duration().expect("trace verify duration"); + let trace_total = trace_total_start.elapsed(); + let trace_phases = trace_run.prove_phase_durations(); + println!( + "MIXED n={} TRACE(prove={}, verify={}, total={}, n_p2={}, m_p2={}, phases: setup={}, chunk_commit={}, fold={}) TRACE_SINGLE_CHUNK(prove={}, verify={}, total={}, n_p2={}, m_p2={}, phases: setup={}, chunk_commit={}, fold={}) ratio_prove={:.2}x", + n, + fmt_duration(trace_prove), + fmt_duration(trace_verify), + fmt_duration(trace_total), + trace_run.ccs_num_constraints().next_power_of_two(), + trace_run.ccs_num_variables().next_power_of_two(), + fmt_duration(trace_phases.setup), + fmt_duration(trace_phases.chunk_build_commit), + fmt_duration(trace_phases.fold_and_prove), + fmt_duration(chunk_prove), + fmt_duration(chunk_verify), + fmt_duration(chunk_total), + chunk_run.ccs_num_constraints().next_power_of_two(), + chunk_run.ccs_num_variables().next_power_of_two(), + fmt_duration(chunk_phases.setup), + fmt_duration(chunk_phases.chunk_build_commit), + fmt_duration(chunk_phases.fold_and_prove), + trace_prove.as_secs_f64() / chunk_prove.as_secs_f64(), + ); + } + Err(e) => { + println!( + "MIXED n={} TRACE(prove=ERROR:{}) TRACE_SINGLE_CHUNK(prove={}, verify={}, total={}, n_p2={}, m_p2={})", + n, + e, + fmt_duration(chunk_prove), + 
fmt_duration(chunk_verify), + fmt_duration(chunk_total), + chunk_run.ccs_num_constraints().next_power_of_two(), + chunk_run.ccs_num_variables().next_power_of_two(), + ); + } + } +} + +#[derive(Clone, Copy, Debug)] +struct PerfSample { + end_to_end: Duration, + prove: Duration, + verify: Duration, + setup: Duration, + build_commit: Duration, + fold: Duration, +} + +fn median_duration(values: &[Duration]) -> Duration { + let mut nanos: Vec = values.iter().map(|d| d.as_nanos()).collect(); + nanos.sort_unstable(); + Duration::from_nanos(nanos[nanos.len() / 2] as u64) +} + +fn spread_pct(values: &[Duration], median: Duration) -> f64 { + if values.is_empty() || median.is_zero() { + return 0.0; + } + let med = median.as_secs_f64(); + let max_abs = values + .iter() + .map(|v| (v.as_secs_f64() - med).abs()) + .fold(0.0f64, f64::max); + (max_abs / med) * 100.0 +} + +fn build_mixed_program(n: usize) -> Vec { + let base = mixed_instruction_sequence(); + let mut program: Vec = (0..n).map(|i| base[i % base.len()].clone()).collect(); + program.push(RiscvInstruction::Halt); + program +} + +fn run_trace_sample(program: &[RiscvInstruction]) -> PerfSample { + let steps = program.len(); + let program_bytes = encode_program(program); + let total_start = Instant::now(); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .min_trace_len(steps) + .max_steps(steps) + .chunk_rows(steps) + .prove() + .expect("trace prove"); + let prove = run.prove_duration(); + let phases = run.prove_phase_durations(); + run.verify().expect("trace verify"); + let verify = run.verify_duration().expect("trace verify duration"); + PerfSample { + end_to_end: total_start.elapsed(), + prove, + verify, + setup: phases.setup, + build_commit: phases.chunk_build_commit, + fold: phases.fold_and_prove, + } +} + +fn run_single_chunk_trace_sample(program: &[RiscvInstruction]) -> PerfSample { + let steps = program.len(); + let program_bytes = encode_program(program); + let total_start = Instant::now(); + let mut 
run = Rv32TraceWiring::from_rom(0, &program_bytes) + .min_trace_len(steps) + .chunk_rows(steps) + .max_steps(steps) + .prove() + .expect("trace single-chunk prove"); + let prove = run.prove_duration(); + let phases = run.prove_phase_durations(); + run.verify().expect("trace single-chunk verify"); + let verify = run + .verify_duration() + .expect("trace single-chunk verify duration"); + PerfSample { + end_to_end: total_start.elapsed(), + prove, + verify, + setup: phases.setup, + build_commit: phases.chunk_build_commit, + fold: phases.fold_and_prove, + } +} + +fn report_samples(label: &str, samples: &[PerfSample]) { + let end_vals: Vec = samples.iter().map(|s| s.end_to_end).collect(); + let prove_vals: Vec = samples.iter().map(|s| s.prove).collect(); + let verify_vals: Vec = samples.iter().map(|s| s.verify).collect(); + let setup_vals: Vec = samples.iter().map(|s| s.setup).collect(); + let build_vals: Vec = samples.iter().map(|s| s.build_commit).collect(); + let fold_vals: Vec = samples.iter().map(|s| s.fold).collect(); + let prove_window_vals: Vec = samples + .iter() + .map(|s| s.setup + s.build_commit + s.fold) + .collect(); + + let end_med = median_duration(&end_vals); + let prove_med = median_duration(&prove_vals); + let verify_med = median_duration(&verify_vals); + let setup_med = median_duration(&setup_vals); + let build_med = median_duration(&build_vals); + let fold_med = median_duration(&fold_vals); + let prove_window_med = median_duration(&prove_window_vals); + + println!( + "{}: median(end={}, prove_api={}, prove_window={}, verify={}, setup={}, build_commit={}, fold={}) spread(end={:.2}%, prove_window={:.2}%, fold={:.2}%)", + label, + fmt_duration(end_med), + fmt_duration(prove_med), + fmt_duration(prove_window_med), + fmt_duration(verify_med), + fmt_duration(setup_med), + fmt_duration(build_med), + fmt_duration(fold_med), + spread_pct(&end_vals, end_med), + spread_pct(&prove_window_vals, prove_window_med), + spread_pct(&fold_vals, fold_med), + ); +} + 
+#[test] +#[ignore = "perf baseline report: cargo test -p neo-fold --release --test perf -- --ignored --nocapture report_trace_vs_chunked_medians"] +fn report_trace_vs_chunked_medians() { + const RUNS: usize = 5; + let cases = [ + ("mixed", 10usize, build_mixed_program(10)), + ("mixed", 256usize, build_mixed_program(256)), + ]; + + for (kind, n, program) in cases { + let mut trace_samples = Vec::with_capacity(RUNS); + let mut chunked_samples = Vec::with_capacity(RUNS); + for _ in 0..RUNS { + trace_samples.push(run_trace_sample(&program)); + chunked_samples.push(run_single_chunk_trace_sample(&program)); + } + println!("CASE kind={} n={} runs={}", kind, n, RUNS); + report_samples("TRACE", &trace_samples); + report_samples("TRACE_SINGLE_CHUNK", &chunked_samples); + } +} diff --git a/crates/neo-fold/tests/suites/redteam/mod.rs b/crates/neo-fold/tests/suites/redteam/mod.rs new file mode 100644 index 00000000..ab8ff3fb --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam/mod.rs @@ -0,0 +1 @@ +mod riscv_verifier_gaps; diff --git a/crates/neo-fold/tests/suites/redteam/riscv_verifier_gaps.rs b/crates/neo-fold/tests/suites/redteam/riscv_verifier_gaps.rs new file mode 100644 index 00000000..108fb411 --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam/riscv_verifier_gaps.rs @@ -0,0 +1,97 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::{F, K}; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn addi_halt_program_bytes(imm: i32) -> Vec { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm, + }, + RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +fn addi_sw_halt_program_bytes(value: i32, addr: i32) -> Vec { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: value, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: addr, + }, + 
RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +#[test] +fn redteam_output_claim_path_rejects_tampered_proof() { + let program_bytes = addi_sw_halt_program_bytes(/*value=*/ 42, /*addr=*/ 0x100); + let steps = 4usize; + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .output(/*output_addr=*/ 0x100, /*expected_output=*/ F::from_u64(42)) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut bad_proof = run.proof().clone(); + let mut tampered = false; + for step in &mut bad_proof.steps { + for claim in &mut step.mem.val_me_claims { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + tampered = true; + break; + } + } + if tampered { + break; + } + } + assert!(tampered, "expected at least one scalar to tamper"); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered proof should fail full verification" + ); +} + +#[test] +fn redteam_verifier_rejects_spliced_proofs_across_runs() { + let program_bytes_a = addi_halt_program_bytes(/*imm=*/ 7); + let steps = 4usize; + let mut run_a = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes_a) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove a"); + run_a.verify().expect("verify a"); + + let program_bytes_b = addi_halt_program_bytes(/*imm=*/ 8); + let mut run_b = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes_b) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove b"); + run_b.verify().expect("verify b"); + + assert!( + run_a.verify_proof(run_b.proof()).is_err(), + "spliced proof across runs must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/helpers.rs b/crates/neo-fold/tests/suites/redteam_riscv/helpers.rs new file mode 100644 index 00000000..cc7fe398 --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/helpers.rs 
@@ -0,0 +1,75 @@ +use neo_ajtai::{AjtaiSModule, Commitment as Cmt}; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_math::K; +use neo_memory::ajtai::encode_vector_balanced_to_mat; +use neo_memory::witness::StepWitnessBundle; +use neo_params::NeoParams; +use p3_goldilocks::Goldilocks as F; + +pub type StepWit = StepWitnessBundle; + +pub fn collect_mcs(instances: &[StepWit]) -> (Vec>, Vec>) { + let mut insts = Vec::with_capacity(instances.len()); + let mut wits = Vec::with_capacity(instances.len()); + for step in instances { + let (inst, wit) = &step.mcs; + insts.push(inst.clone()); + wits.push(wit.clone()); + } + (insts, wits) +} + +pub fn mcs_recommit_step_after_private_tamper( + params: &NeoParams, + committer: &AjtaiSModule, + mcs_inst: &mut McsInstance, + mcs_wit: &mut McsWitness, + idx_to_tamper: usize, + delta: F, +) { + let m_in = mcs_inst.m_in; + assert!( + idx_to_tamper >= m_in, + "expected idx_to_tamper to be in the private witness region (idx={idx_to_tamper}, m_in={m_in})" + ); + + let mut z = Vec::with_capacity(m_in + mcs_wit.w.len()); + z.extend_from_slice(&mcs_inst.x); + z.extend_from_slice(&mcs_wit.w); + assert!( + idx_to_tamper < z.len(), + "idx_to_tamper out of range: idx={idx_to_tamper} len={}", + z.len() + ); + + let x_before = mcs_inst.x.clone(); + z[idx_to_tamper] += delta; + assert_eq!( + &z[..m_in], + &x_before[..], + "tamper helper must not modify public x region" + ); + + mcs_wit.w = z[m_in..].to_vec(); + mcs_wit.Z = encode_vector_balanced_to_mat(params, &z); + mcs_inst.c = committer.commit(&mcs_wit.Z); +} + +pub fn step_bundle_recommit_after_private_tamper( + params: &NeoParams, + committer: &AjtaiSModule, + step: &mut StepWit, + idx_to_tamper: usize, + delta: F, +) { + let (ref mut inst, ref mut wit) = step.mcs; + mcs_recommit_step_after_private_tamper(params, committer, inst, wit, idx_to_tamper, delta); +} + +pub fn assert_prove_or_verify_fails(res: Result, label: &str) { + match res 
{ + Ok(true) => panic!("{label}: unexpectedly verified"), + Ok(false) | Err(_) => {} + } +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/mod.rs b/crates/neo-fold/tests/suites/redteam_riscv/mod.rs new file mode 100644 index 00000000..c12f028a --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/mod.rs @@ -0,0 +1,8 @@ +mod riscv_bus_binding_redteam; +mod riscv_decode_malicious_witness_redteam; +mod riscv_decode_plumbing_linkage; +mod riscv_main_proof_redteam; +mod riscv_semantics_malicious_witness_redteam; +mod riscv_semantics_sidecar_linkage; +mod riscv_twist_shout_redteam; +mod rv32m_sidecar_linkage; diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_bus_binding_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_bus_binding_redteam.rs new file mode 100644 index 00000000..5c2693a7 --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_bus_binding_redteam.rs @@ -0,0 +1,92 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn prove_run(program: Vec, max_steps: usize) -> Rv32TraceWiringRun { + let steps = max_steps; + let program_bytes = encode_program(&program); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn tamper_any_claim_scalar(proof: &mut ShardProof) { + for step in &mut proof.steps { + for claims in [ + &mut step.fold.ccs_out, + &mut step.mem.val_me_claims, + &mut step.mem.wb_me_claims, + &mut step.mem.wp_me_claims, + ] { + for claim in claims.iter_mut() { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + return; + } + } + } + } + panic!("expected at least one claim scalar to tamper"); +} + 
+#[test] +fn rv32_trace_cpu_vs_bus_twist_rv_mismatch_must_fail() { + // Program: LW x1, 0(x0); HALT, with RAM[0]=7. + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let steps = 2usize; + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .ram_init_u32(/*addr=*/ 0, /*value=*/ 7) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut bad_proof = run.proof().clone(); + tamper_any_claim_scalar(&mut bad_proof); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered bus/twist binding must not verify" + ); +} + +#[test] +fn rv32_trace_cpu_vs_bus_shout_val_mismatch_must_fail() { + // Program: ADDI x1, x0, 1; HALT (forces an ADD shout lookup). + let run = prove_run( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ], + /*max_steps=*/ 2, + ); + + let mut bad_proof = run.proof().clone(); + tamper_any_claim_scalar(&mut bad_proof); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered bus/shout binding must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_decode_malicious_witness_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_decode_malicious_witness_redteam.rs new file mode 100644 index 00000000..5711a56e --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_decode_malicious_witness_redteam.rs @@ -0,0 +1,60 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn prove_run_addi_halt(imm: i32) -> Rv32TraceWiringRun { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + 
imm, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let steps = 2usize; + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn tamper_wp_scalar(run: &Rv32TraceWiringRun) { + let mut bad_proof = run.proof().clone(); + let mut tampered = false; + for step in &mut bad_proof.steps { + for claim in &mut step.mem.wp_me_claims { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + tampered = true; + break; + } + } + if tampered { + break; + } + } + assert!(tampered, "expected at least one wp claim scalar to tamper"); + assert!( + run.verify_proof(&bad_proof).is_err(), + "decode-related malicious tamper must not verify" + ); +} + +#[test] +fn rv32_trace_decode_malicious_imm_i_must_fail() { + let run = prove_run_addi_halt(/*imm=*/ 1); + tamper_wp_scalar(&run); +} + +#[test] +fn rv32_trace_decode_malicious_rd_field_must_fail() { + let run = prove_run_addi_halt(/*imm=*/ 2); + tamper_wp_scalar(&run); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_decode_plumbing_linkage.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_decode_plumbing_linkage.rs new file mode 100644 index 00000000..0780d582 --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_decode_plumbing_linkage.rs @@ -0,0 +1,60 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn prove_run_addi_halt(imm: i32) -> Rv32TraceWiringRun { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let steps = 2usize; + let mut run = 
Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn tamper_decode_related_scalar(proof: &mut ShardProof) { + for step in &mut proof.steps { + for claim in &mut step.mem.wp_me_claims { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + return; + } + } + } + panic!("expected at least one decode-related scalar in wp claims"); +} + +#[test] +fn rv32_trace_decode_plumbing_tampered_scalar_must_not_verify() { + let run = prove_run_addi_halt(/*imm=*/ 1); + let mut bad_proof = run.proof().clone(); + tamper_decode_related_scalar(&mut bad_proof); + assert!( + run.verify_proof(&bad_proof).is_err(), + "decode-related tamper must not verify" + ); +} + +#[test] +fn rv32_trace_decode_plumbing_splicing_across_runs_must_fail() { + let run_a = prove_run_addi_halt(/*imm=*/ 1); + let run_b = prove_run_addi_halt(/*imm=*/ 2); + assert!( + run_a.verify_proof(run_b.proof()).is_err(), + "spliced decode commitments must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_main_proof_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_main_proof_redteam.rs new file mode 100644 index 00000000..e521aeb4 --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_main_proof_redteam.rs @@ -0,0 +1,106 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn addi_halt_program_bytes(imm: i32) -> Vec { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm, + }, + RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +fn prove_run(program_bytes: &[u8], max_steps: usize) -> Rv32TraceWiringRun { + let steps = max_steps; + let 
mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn tamper_any_claim_scalar(proof: &mut ShardProof) { + for step in &mut proof.steps { + for claims in [ + &mut step.fold.ccs_out, + &mut step.mem.val_me_claims, + &mut step.mem.wb_me_claims, + &mut step.mem.wp_me_claims, + ] { + for claim in claims.iter_mut() { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + return; + } + } + } + } + panic!("expected at least one claim scalar to tamper"); +} + +#[test] +fn rv32_trace_main_proof_truncated_steps_must_fail() { + let program_bytes = addi_halt_program_bytes(/*imm=*/ 1); + let run = prove_run(&program_bytes, /*max_steps=*/ 2); + + let mut bad_proof = run.proof().clone(); + bad_proof.steps.clear(); + assert!( + run.verify_proof(&bad_proof).is_err(), + "truncated main proof must not verify" + ); +} + +#[test] +fn rv32_trace_main_proof_tamper_claim_must_fail() { + let program_bytes = addi_halt_program_bytes(/*imm=*/ 1); + let run = prove_run(&program_bytes, /*max_steps=*/ 2); + + let mut bad_proof = run.proof().clone(); + tamper_any_claim_scalar(&mut bad_proof); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered main proof must not verify" + ); +} + +#[test] +fn rv32_trace_main_proof_step_reordering_must_fail() { + let program_bytes = addi_halt_program_bytes(/*imm=*/ 1); + let run = prove_run(&program_bytes, /*max_steps=*/ 2); + + let mut bad_proof = run.proof().clone(); + if bad_proof.steps.len() >= 2 { + bad_proof.steps.swap(0, 1); + } else { + tamper_any_claim_scalar(&mut bad_proof); + } + assert!( + run.verify_proof(&bad_proof).is_err(), + "reordered proof steps must not verify" + ); +} + +#[test] +fn rv32_trace_main_proof_splicing_across_runs_must_fail() { + let program_bytes_a = addi_halt_program_bytes(/*imm=*/ 1); + let run_a = prove_run(&program_bytes_a, 
/*max_steps=*/ 2); + + let program_bytes_b = addi_halt_program_bytes(/*imm=*/ 2); + let run_b = prove_run(&program_bytes_b, /*max_steps=*/ 2); + + assert!( + run_a.verify_proof(run_b.proof()).is_err(), + "splicing proof across runs must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_semantics_malicious_witness_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_semantics_malicious_witness_redteam.rs new file mode 100644 index 00000000..21606a2e --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_semantics_malicious_witness_redteam.rs @@ -0,0 +1,115 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn prove_run(program: Vec, max_steps: usize) -> Rv32TraceWiringRun { + let steps = max_steps; + let program_bytes = encode_program(&program); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn tamper_val_scalar(run: &Rv32TraceWiringRun) { + let mut bad_proof = run.proof().clone(); + let mut tampered = false; + for step in &mut bad_proof.steps { + for claim in &mut step.mem.val_me_claims { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + tampered = true; + break; + } + } + if tampered { + break; + } + } + assert!(tampered, "expected at least one val claim scalar to tamper"); + assert!( + run.verify_proof(&bad_proof).is_err(), + "semantics-related malicious tamper must not verify" + ); +} + +#[test] +fn rv32_trace_semantics_malicious_alu_out_must_fail() { + let run = prove_run( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ], + /*max_steps=*/ 
2, + ); + tamper_val_scalar(&run); +} + +#[test] +fn rv32_trace_semantics_malicious_eff_addr_must_fail() { + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let steps = 2usize; + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .ram_init_u32(/*addr=*/ 0, /*value=*/ 7) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + tamper_val_scalar(&run); +} + +#[test] +fn rv32_trace_semantics_malicious_ram_wv_must_fail() { + let run = prove_run( + vec![ + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ], + /*max_steps=*/ 2, + ); + tamper_val_scalar(&run); +} + +#[test] +fn rv32_trace_semantics_malicious_br_taken_must_fail() { + let run = prove_run( + vec![ + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 0, + rs2: 0, + imm: 8, + }, + RiscvInstruction::Nop, + RiscvInstruction::Halt, + ], + /*max_steps=*/ 2, + ); + tamper_val_scalar(&run); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_semantics_sidecar_linkage.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_semantics_sidecar_linkage.rs new file mode 100644 index 00000000..68ec3ca0 --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_semantics_sidecar_linkage.rs @@ -0,0 +1,60 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn prove_run_addi_halt(imm: i32) -> Rv32TraceWiringRun { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + 
let steps = 2usize; + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn tamper_semantics_related_scalar(proof: &mut ShardProof) { + for step in &mut proof.steps { + for claim in &mut step.mem.val_me_claims { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + return; + } + } + } + panic!("expected at least one semantics-related scalar in val claims"); +} + +#[test] +fn rv32_trace_semantics_tampered_scalar_must_not_verify() { + let run = prove_run_addi_halt(/*imm=*/ 1); + let mut bad_proof = run.proof().clone(); + tamper_semantics_related_scalar(&mut bad_proof); + assert!( + run.verify_proof(&bad_proof).is_err(), + "semantics-related tamper must not verify" + ); +} + +#[test] +fn rv32_trace_semantics_splicing_across_runs_must_fail() { + let run_a = prove_run_addi_halt(/*imm=*/ 1); + let run_b = prove_run_addi_halt(/*imm=*/ 2); + assert!( + run_a.verify_proof(run_b.proof()).is_err(), + "spliced semantics commitments must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_twist_shout_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_twist_shout_redteam.rs new file mode 100644 index 00000000..84dffdc3 --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_twist_shout_redteam.rs @@ -0,0 +1,165 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn prove_run(program: Vec, max_steps: usize) -> Rv32TraceWiringRun { + let steps = max_steps; + let program_bytes = encode_program(&program); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) 
+ .max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn tamper_any_claim_scalar(proof: &mut ShardProof) { + for step in &mut proof.steps { + for claims in [ + &mut step.fold.ccs_out, + &mut step.mem.val_me_claims, + &mut step.mem.wb_me_claims, + &mut step.mem.wp_me_claims, + ] { + for claim in claims.iter_mut() { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + return; + } + } + } + } + panic!("expected at least one claim scalar to tamper"); +} + +#[test] +fn rv32_trace_twist_claim_tamper_must_fail() { + let run = prove_run( + vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ], + /*max_steps=*/ 2, + ); + + let mut bad_proof = run.proof().clone(); + tamper_any_claim_scalar(&mut bad_proof); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered twist proof must not verify" + ); +} + +#[test] +fn rv32_trace_shout_addr_pre_tamper_must_fail() { + let run = prove_run( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ], + /*max_steps=*/ 2, + ); + + let mut bad_proof = run.proof().clone(); + let step = bad_proof + .steps + .iter_mut() + .find(|s| { + s.mem + .shout_addr_pre + .groups + .iter() + .any(|g| !g.active_lanes.is_empty()) + }) + .expect("expected an active shout addr-pre group"); + let group = step + .mem + .shout_addr_pre + .groups + .iter_mut() + .find(|g| !g.active_lanes.is_empty()) + .expect("expected non-empty active lanes"); + group.active_lanes.clear(); + group.round_polys.clear(); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered shout addr-pre proof must not verify" + ); +} + +#[test] +fn rv32_trace_proof_step_reordering_must_fail() { + let run = prove_run( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 2, + rs1: 1, + 
imm: 3, + }, + RiscvInstruction::Halt, + ], + /*max_steps=*/ 3, + ); + + let mut bad_proof = run.proof().clone(); + if bad_proof.steps.len() >= 2 { + bad_proof.steps.swap(0, 1); + } else { + tamper_any_claim_scalar(&mut bad_proof); + } + assert!( + run.verify_proof(&bad_proof).is_err(), + "reordered proof steps must not verify" + ); +} + +#[test] +fn rv32_trace_ram_init_statement_tamper_must_fail() { + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let steps = 2usize; + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + .max_steps(steps) + .ram_init_u32(/*addr=*/ 0, /*value=*/ 7) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut bad_proof = run.proof().clone(); + tamper_any_claim_scalar(&mut bad_proof); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered RAM-bound proof must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/redteam_riscv/rv32m_sidecar_linkage.rs b/crates/neo-fold/tests/suites/redteam_riscv/rv32m_sidecar_linkage.rs new file mode 100644 index 00000000..ad13c40b --- /dev/null +++ b/crates/neo-fold/tests/suites/redteam_riscv/rv32m_sidecar_linkage.rs @@ -0,0 +1,48 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn rv32_trace_claims_are_bound_to_main_commitment() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let steps = 2usize; + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(steps) + .min_trace_len(steps) + 
.max_steps(steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut bad_proof = run.proof().clone(); + let mut tampered = false; + for step in &mut bad_proof.steps { + for claim in &mut step.mem.val_me_claims { + if let Some(first) = claim.y_scalars.first_mut() { + *first += K::ONE; + tampered = true; + break; + } + } + if tampered { + break; + } + } + assert!(tampered, "expected at least one claim scalar to tamper"); + assert!( + run.verify_proof(&bad_proof).is_err(), + "tampered trace claims must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/regression/ccs_builder_shape.rs b/crates/neo-fold/tests/suites/regression/ccs_builder_shape.rs new file mode 100644 index 00000000..58e5bc01 --- /dev/null +++ b/crates/neo-fold/tests/suites/regression/ccs_builder_shape.rs @@ -0,0 +1,14 @@ +use neo_fold::session::CcsBuilder; +use neo_math::F; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn ccs_builder_square_does_not_insert_identity_matrix() { + let mut cs = CcsBuilder::::new(1, 0).expect("CcsBuilder::new"); + cs.r1cs_terms([(0, F::from_u64(2))], [(0, F::ONE)], [(0, F::ONE)]); + + // n == m == 1 (square), but builder should keep 3-matrix embedding. 
+ let ccs = cs.build_rect(1, 0).expect("build_rect"); + assert_eq!(ccs.t(), 3, "square build_rect must not auto-insert identity matrix"); + assert!(!ccs.matrices[0].is_identity(), "M0 should be A, not identity"); +} diff --git a/crates/neo-fold/tests/suites/regression/mod.rs b/crates/neo-fold/tests/suites/regression/mod.rs new file mode 100644 index 00000000..903fba60 --- /dev/null +++ b/crates/neo-fold/tests/suites/regression/mod.rs @@ -0,0 +1,2 @@ +mod ccs_builder_shape; +mod test_regression; diff --git a/crates/neo-fold/tests/test_regression.rs b/crates/neo-fold/tests/suites/regression/test_regression.rs similarity index 93% rename from crates/neo-fold/tests/test_regression.rs rename to crates/neo-fold/tests/suites/regression/test_regression.rs index d58ad0a9..c0c60897 100644 --- a/crates/neo-fold/tests/test_regression.rs +++ b/crates/neo-fold/tests/suites/regression/test_regression.rs @@ -69,8 +69,8 @@ fn test_regression_optimized_all_public_inputs() { #[test] fn test_regression_optimized_normalizes_identity_first() { - // Regression: session should accept a square CCS that is not identity-first - // by calling `ensure_identity_first()` internally. + // Regression: session should accept a square CCS that uses plain 3-matrix + // R1CS embedding (no identity-first matrix). 
let n_constraints = 3usize; let n_vars = 3usize; @@ -108,5 +108,5 @@ fn test_regression_optimized_normalizes_identity_first() { let public_mcss = session.mcss_public(); let ok = session.verify(&ccs, &public_mcss, &run).expect("verify"); - assert!(ok, "verification should pass after identity-first normalization"); + assert!(ok, "verification should pass for non-identity-first 3-matrix CCS"); } diff --git a/crates/neo-fold/tests/suites/rv32m/mod.rs b/crates/neo-fold/tests/suites/rv32m/mod.rs new file mode 100644 index 00000000..4518388b --- /dev/null +++ b/crates/neo-fold/tests/suites/rv32m/mod.rs @@ -0,0 +1,2 @@ +mod riscv_rv32m_mul_divu_remu_prove_verify; +mod rv32m_sidecar_sparse_steps; diff --git a/crates/neo-fold/tests/suites/rv32m/riscv_rv32m_mul_divu_remu_prove_verify.rs b/crates/neo-fold/tests/suites/rv32m/riscv_rv32m_mul_divu_remu_prove_verify.rs new file mode 100644 index 00000000..eb06a529 --- /dev/null +++ b/crates/neo-fold/tests/suites/rv32m/riscv_rv32m_mul_divu_remu_prove_verify.rs @@ -0,0 +1,134 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +#[test] +fn rv32_trace_prove_verify_add_sub_sequence() { + let program = vec![ + // x1 = 7 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, + }, + // x2 = 13 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 13, + }, + // x3 = x1 + x2 = 20 + RiscvInstruction::RAlu { + op: RiscvOpcode::Add, + rd: 3, + rs1: 1, + rs2: 2, + }, + // x4 = x2 - x1 = 6 + RiscvInstruction::RAlu { + op: RiscvOpcode::Sub, + rd: 4, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(program.len()) + .min_trace_len(program.len()) + .max_steps(program.len()) + 
.reg_output_claim(/*reg=*/ 3, /*expected=*/ F::from_u64(20)) + .reg_output_claim(/*reg=*/ 4, /*expected=*/ F::from_u64(6)) + .prove() + .expect("prove"); + run.verify().expect("verify"); +} + +#[test] +fn rv32_trace_prove_verify_sltu_and_zero_flag_path() { + let program = vec![ + // x1 = 5 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 5, + }, + // x2 = 5 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 5, + }, + // x3 = (x1 < x2) ? 1 : 0 => 0 + RiscvInstruction::RAlu { + op: RiscvOpcode::Sltu, + rd: 3, + rs1: 1, + rs2: 2, + }, + // x4 = x3 + 1 => 1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 4, + rs1: 3, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(program.len()) + .min_trace_len(program.len()) + .max_steps(program.len()) + .reg_output_claim(/*reg=*/ 3, /*expected=*/ F::from_u64(0)) + .reg_output_claim(/*reg=*/ 4, /*expected=*/ F::from_u64(1)) + .prove() + .expect("prove"); + run.verify().expect("verify"); +} + +#[test] +fn rv32_trace_prove_verify_signed_compare_path() { + let program = vec![ + // x1 = -7 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: -7, + }, + // x2 = 3 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 3, + }, + // x3 = (x1 < x2) signed => 1 + RiscvInstruction::RAlu { + op: RiscvOpcode::Slt, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(program.len()) + .min_trace_len(program.len()) + .max_steps(program.len()) + .reg_output_claim(/*reg=*/ 3, /*expected=*/ F::from_u64(1)) + .prove() + .expect("prove"); + run.verify().expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/rv32m/rv32m_sidecar_sparse_steps.rs 
b/crates/neo-fold/tests/suites/rv32m/rv32m_sidecar_sparse_steps.rs new file mode 100644 index 00000000..b4cb2daf --- /dev/null +++ b/crates/neo-fold/tests/suites/rv32m/rv32m_sidecar_sparse_steps.rs @@ -0,0 +1,117 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; + +fn run_trace(program: &[RiscvInstruction]) -> neo_fold::riscv_trace_shard::Rv32TraceWiringRun { + let program_bytes = encode_program(program); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(program.len()) + .min_trace_len(program.len()) + .max_steps(program.len()) + .prove() + .expect("prove"); + run.verify().expect("verify"); + run +} + +#[test] +fn trace_program_without_ram_ops_has_no_ram_events() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let run = run_trace(&program); + let ram_rows = run + .exec_table() + .rows + .iter() + .filter(|row| !row.ram_events.is_empty()) + .count(); + assert_eq!(ram_rows, 0, "expected no RAM events in non-memory program"); +} + +#[test] +fn trace_rows_are_sparse_over_time_for_store_load() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 12, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0, + }, + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let run = run_trace(&program); + + let ram_rows: Vec = run + .exec_table() + .rows + .iter() + .enumerate() + .filter_map(|(idx, row)| (!row.ram_events.is_empty()).then_some(idx)) + .collect(); + assert_eq!(ram_rows, vec![1, 2], "expected RAM rows only on SW/LW steps"); +} + +#[test] +fn trace_rows_select_only_expected_opcodes() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 3, + }, + 
RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 2, + rs1: 1, + imm: 1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sub, + rd: 3, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let run = run_trace(&program); + + let op_rows: Vec = run + .exec_table() + .rows + .iter() + .enumerate() + .filter_map(|(idx, row)| { + matches!( + row.decoded, + Some(RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + .. + }) | Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Sub, + .. + }) + ) + .then_some(idx) + }) + .collect(); + assert_eq!(op_rows, vec![1, 2], "expected OR/SUB rows at indices 1 and 2"); +} diff --git a/crates/neo-fold/tests/suites/session/mod.rs b/crates/neo-fold/tests/suites/session/mod.rs new file mode 100644 index 00000000..797b7487 --- /dev/null +++ b/crates/neo-fold/tests/suites/session/mod.rs @@ -0,0 +1,10 @@ +mod session_basic_crosscheck; +mod session_basic_optimized_engine; +mod session_basic_paper_exact; +mod session_layout_dsl_tests; +mod session_multifold_r1cs_crosscheck; +mod session_multifold_r1cs_optimized; +mod session_multifold_r1cs_paper_exact; +mod session_multifold_r1cs_paper_exact_nontrivial; +mod session_step_linking_policy; +mod session_ux_helpers; diff --git a/crates/neo-fold/tests/session_basic_crosscheck.rs b/crates/neo-fold/tests/suites/session/session_basic_crosscheck.rs similarity index 100% rename from crates/neo-fold/tests/session_basic_crosscheck.rs rename to crates/neo-fold/tests/suites/session/session_basic_crosscheck.rs diff --git a/crates/neo-fold/tests/session_basic_optimized_engine.rs b/crates/neo-fold/tests/suites/session/session_basic_optimized_engine.rs similarity index 100% rename from crates/neo-fold/tests/session_basic_optimized_engine.rs rename to crates/neo-fold/tests/suites/session/session_basic_optimized_engine.rs diff --git a/crates/neo-fold/tests/session_basic_paper_exact.rs b/crates/neo-fold/tests/suites/session/session_basic_paper_exact.rs similarity index 100% rename from 
crates/neo-fold/tests/session_basic_paper_exact.rs rename to crates/neo-fold/tests/suites/session/session_basic_paper_exact.rs diff --git a/crates/neo-fold/tests/session_layout_dsl_tests.rs b/crates/neo-fold/tests/suites/session/session_layout_dsl_tests.rs similarity index 100% rename from crates/neo-fold/tests/session_layout_dsl_tests.rs rename to crates/neo-fold/tests/suites/session/session_layout_dsl_tests.rs diff --git a/crates/neo-fold/tests/session_multifold_r1cs_crosscheck.rs b/crates/neo-fold/tests/suites/session/session_multifold_r1cs_crosscheck.rs similarity index 100% rename from crates/neo-fold/tests/session_multifold_r1cs_crosscheck.rs rename to crates/neo-fold/tests/suites/session/session_multifold_r1cs_crosscheck.rs diff --git a/crates/neo-fold/tests/session_multifold_r1cs_optimized.rs b/crates/neo-fold/tests/suites/session/session_multifold_r1cs_optimized.rs similarity index 100% rename from crates/neo-fold/tests/session_multifold_r1cs_optimized.rs rename to crates/neo-fold/tests/suites/session/session_multifold_r1cs_optimized.rs diff --git a/crates/neo-fold/tests/session_multifold_r1cs_paper_exact.rs b/crates/neo-fold/tests/suites/session/session_multifold_r1cs_paper_exact.rs similarity index 100% rename from crates/neo-fold/tests/session_multifold_r1cs_paper_exact.rs rename to crates/neo-fold/tests/suites/session/session_multifold_r1cs_paper_exact.rs diff --git a/crates/neo-fold/tests/session_multifold_r1cs_paper_exact_nontrivial.rs b/crates/neo-fold/tests/suites/session/session_multifold_r1cs_paper_exact_nontrivial.rs similarity index 100% rename from crates/neo-fold/tests/session_multifold_r1cs_paper_exact_nontrivial.rs rename to crates/neo-fold/tests/suites/session/session_multifold_r1cs_paper_exact_nontrivial.rs diff --git a/crates/neo-fold/tests/session_step_linking_policy.rs b/crates/neo-fold/tests/suites/session/session_step_linking_policy.rs similarity index 100% rename from crates/neo-fold/tests/session_step_linking_policy.rs rename to 
crates/neo-fold/tests/suites/session/session_step_linking_policy.rs diff --git a/crates/neo-fold/tests/session_ux_helpers.rs b/crates/neo-fold/tests/suites/session/session_ux_helpers.rs similarity index 100% rename from crates/neo-fold/tests/session_ux_helpers.rs rename to crates/neo-fold/tests/suites/session/session_ux_helpers.rs diff --git a/crates/neo-fold/tests/cpu_bus_semantics_fork_attack.rs b/crates/neo-fold/tests/suites/shared_bus/cpu_bus_semantics_fork_attack.rs similarity index 93% rename from crates/neo-fold/tests/cpu_bus_semantics_fork_attack.rs rename to crates/neo-fold/tests/suites/shared_bus/cpu_bus_semantics_fork_attack.rs index c8ce08cc..8f40a7eb 100644 --- a/crates/neo-fold/tests/cpu_bus_semantics_fork_attack.rs +++ b/crates/neo-fold/tests/suites/shared_bus/cpu_bus_semantics_fork_attack.rs @@ -33,77 +33,28 @@ #![allow(non_snake_case)] use std::marker::PhantomData; -use std::sync::Arc; -use neo_ajtai::{s_lincomb, s_mul, setup as ajtai_setup, AjtaiSModule, Commitment as Cmt}; +use neo_ajtai::Commitment as Cmt; use neo_ccs::relations::{CcsStructure, McsInstance, McsWitness}; use neo_ccs::traits::SModuleHomomorphism; use neo_ccs::Mat; use neo_fold::pi_ccs::FoldingMode; use neo_fold::shard::{ fold_shard_prove as fold_shard_prove_shared_cpu_bus, fold_shard_verify as fold_shard_verify_shared_cpu_bus, - CommitMixers, }; -use neo_math::ring::Rq as RqEl; -use neo_math::{D, F, K}; +use neo_math::{F, K}; use neo_memory::plain::{PlainMemLayout, PlainMemTrace}; use neo_memory::witness::{LutInstance, LutWitness, MemInstance, MemWitness, StepInstanceBundle, StepWitnessBundle}; use neo_memory::MemInit; use neo_params::NeoParams; use neo_transcript::{Poseidon2Transcript, Transcript}; use p3_field::PrimeCharacteristicRing; -use rand_chacha::rand_core::SeedableRng; -use rand_chacha::ChaCha8Rng; // ============================================================================ // Test Infrastructure // 
============================================================================ -fn setup_ajtai_committer(params: &NeoParams, m: usize) -> AjtaiSModule { - let mut rng = ChaCha8Rng::seed_from_u64(7); - let pp = ajtai_setup(&mut rng, D, params.kappa as usize, m).expect("Ajtai setup should succeed"); - AjtaiSModule::new(Arc::new(pp)) -} - -fn rot_matrix_to_rq(mat: &Mat) -> RqEl { - use neo_math::ring::cf_inv; - - debug_assert_eq!(mat.rows(), D); - debug_assert_eq!(mat.cols(), D); - - let mut coeffs = [F::ZERO; D]; - for i in 0..D { - coeffs[i] = mat[(i, 0)]; - } - cf_inv(coeffs) -} - -fn default_mixers() -> CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt> { - fn mix_rhos_commits(rhos: &[Mat], cs: &[Cmt]) -> Cmt { - assert!(!cs.is_empty(), "mix_rhos_commits: empty commitments"); - if cs.len() == 1 { - return cs[0].clone(); - } - let rq_els: Vec = rhos.iter().map(rot_matrix_to_rq).collect(); - s_lincomb(&rq_els, cs).expect("s_lincomb should succeed") - } - fn combine_b_pows(cs: &[Cmt], b: u32) -> Cmt { - assert!(!cs.is_empty(), "combine_b_pows: empty commitments"); - let mut acc = cs[0].clone(); - let mut pow = F::from_u64(b as u64); - for i in 1..cs.len() { - let rq_pow = RqEl::from_field_scalar(pow); - let term = s_mul(&rq_pow, &cs[i]); - acc.add_inplace(&term); - pow *= F::from_u64(b as u64); - } - acc - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } -} +use crate::suite::{default_mixers, setup_ajtai_committer}; // ============================================================================ // Test: CPU/Memory Semantic Fork Attack @@ -264,6 +215,7 @@ fn cpu_semantic_shadow_fork_attack_should_be_rejected() { }; let mem_inst = MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, @@ -462,6 +414,7 @@ fn cpu_semantic_fork_splice_attack_should_be_rejected() { }; let mem_inst = MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, @@ -641,6 +594,7 @@ fn 
cpu_lookup_shadow_fork_attack_should_be_rejected() { }; let lut_inst = LutInstance:: { + table_id: lut_table.table_id, comms: Vec::new(), k: lut_table.k, d: lut_table.d, @@ -650,6 +604,8 @@ fn cpu_lookup_shadow_fork_attack_should_be_rejected() { ell: lut_table.n_side.trailing_zeros() as usize, table_spec: None, table: lut_table.content.clone(), + addr_group: None, + selector_group: None, }; let lut_wit = LutWitness { mats: Vec::new() }; @@ -672,6 +628,7 @@ fn cpu_lookup_shadow_fork_attack_should_be_rejected() { inc_at_write_addr: vec![F::ZERO], }; let mem_inst = MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, diff --git a/crates/neo-fold/tests/cpu_constraints_fix_vulnerabilities.rs b/crates/neo-fold/tests/suites/shared_bus/cpu_constraints_fix_vulnerabilities.rs similarity index 99% rename from crates/neo-fold/tests/cpu_constraints_fix_vulnerabilities.rs rename to crates/neo-fold/tests/suites/shared_bus/cpu_constraints_fix_vulnerabilities.rs index 983bd20d..ca1fa743 100644 --- a/crates/neo-fold/tests/cpu_constraints_fix_vulnerabilities.rs +++ b/crates/neo-fold/tests/suites/shared_bus/cpu_constraints_fix_vulnerabilities.rs @@ -400,7 +400,7 @@ fn shout_constraints_catch_lookup_attacks() { z[COL_LOOKUP_OUT] = F::from_u64(123); let bus_has_lookup = bus.bus_cell(shout.has_lookup, 0); - let bus_val = bus.bus_cell(shout.val, 0); + let bus_val = bus.bus_cell(shout.primary_val(), 0); z[bus_has_lookup] = F::ONE; z[bus_val] = F::from_u64(456); // Mismatch! @@ -418,7 +418,7 @@ fn shout_constraints_catch_lookup_attacks() { z[COL_CONST_ONE] = F::ONE; let bus_has_lookup = bus.bus_cell(shout.has_lookup, 0); - let bus_val = bus.bus_cell(shout.val, 0); + let bus_val = bus.bus_cell(shout.primary_val(), 0); z[bus_has_lookup] = F::ZERO; z[bus_val] = F::from_u64(999); // Should be 0! 
@@ -438,7 +438,7 @@ fn shout_constraints_catch_lookup_attacks() { z[COL_LOOKUP_OUT] = F::from_u64(42); let bus_has_lookup = bus.bus_cell(shout.has_lookup, 0); - let bus_val = bus.bus_cell(shout.val, 0); + let bus_val = bus.bus_cell(shout.primary_val(), 0); z[bus_has_lookup] = F::ONE; z[bus_val] = F::from_u64(42); // Matches! @@ -543,7 +543,7 @@ fn lookup_key_binding_catches_mismatch() { // Bus: has_lookup=1, val matches, but addr_bits encode 4 (mismatch) z[bus.bus_cell(shout.has_lookup, 0)] = F::ONE; - z[bus.bus_cell(shout.val, 0)] = F::from_u64(42); + z[bus.bus_cell(shout.primary_val(), 0)] = F::from_u64(42); let addr_base = bus.bus_cell(shout.addr_bits.start, 0); // 4 = 0b0100 (little-endian bits: [0,0,1,0]) diff --git a/crates/neo-fold/tests/suites/shared_bus/mod.rs b/crates/neo-fold/tests/suites/shared_bus/mod.rs new file mode 100644 index 00000000..90343f32 --- /dev/null +++ b/crates/neo-fold/tests/suites/shared_bus/mod.rs @@ -0,0 +1,12 @@ +pub(crate) use crate::common_setup::{default_mixers, setup_ajtai_committer}; + +mod cpu_bus_semantics_fork_attack; +mod cpu_constraints_fix_vulnerabilities; +mod shared_cpu_bus_comprehensive_attacks; +mod shared_cpu_bus_control_attacks; +mod shared_cpu_bus_decode_attacks; +mod shared_cpu_bus_layout_consistency; +mod shared_cpu_bus_linkage; +mod shared_cpu_bus_padding_attacks; +mod shared_cpu_bus_width_attacks; +mod ts_route_a_negative; diff --git a/crates/neo-fold/tests/shared_cpu_bus_comprehensive_attacks.rs b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_comprehensive_attacks.rs similarity index 94% rename from crates/neo-fold/tests/shared_cpu_bus_comprehensive_attacks.rs rename to crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_comprehensive_attacks.rs index 374ac11d..d1c1fde2 100644 --- a/crates/neo-fold/tests/shared_cpu_bus_comprehensive_attacks.rs +++ b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_comprehensive_attacks.rs @@ -22,9 +22,8 @@ #![allow(non_snake_case)] use 
std::marker::PhantomData; -use std::sync::Arc; -use neo_ajtai::{s_lincomb, s_mul, setup as ajtai_setup, AjtaiSModule, Commitment as Cmt}; +use neo_ajtai::Commitment as Cmt; use neo_ccs::poly::SparsePoly; use neo_ccs::relations::{CcsStructure, McsInstance, McsWitness}; use neo_ccs::traits::SModuleHomomorphism; @@ -32,10 +31,8 @@ use neo_ccs::Mat; use neo_fold::pi_ccs::FoldingMode; use neo_fold::shard::{ fold_shard_prove as fold_shard_prove_shared_cpu_bus, fold_shard_verify as fold_shard_verify_shared_cpu_bus, - CommitMixers, }; -use neo_math::ring::Rq as RqEl; -use neo_math::{D, F, K}; +use neo_math::{F, K}; use neo_memory::cpu::build_bus_layout_for_instances; use neo_memory::cpu::constraints::{CpuColumnLayout, CpuConstraintBuilder}; use neo_memory::plain::{LutTable, PlainLutTrace, PlainMemLayout, PlainMemTrace}; @@ -44,60 +41,15 @@ use neo_memory::MemInit; use neo_params::NeoParams; use neo_transcript::{Poseidon2Transcript, Transcript}; use p3_field::PrimeCharacteristicRing; -use rand_chacha::rand_core::SeedableRng; -use rand_chacha::ChaCha8Rng; // ============================================================================ // Test Infrastructure // ============================================================================ -fn setup_ajtai_committer(params: &NeoParams, m: usize) -> AjtaiSModule { - let mut rng = ChaCha8Rng::seed_from_u64(7); - let pp = ajtai_setup(&mut rng, D, params.kappa as usize, m).expect("Ajtai setup should succeed"); - AjtaiSModule::new(Arc::new(pp)) -} - -fn rot_matrix_to_rq(mat: &Mat) -> RqEl { - use neo_math::ring::cf_inv; - - debug_assert_eq!(mat.rows(), D); - debug_assert_eq!(mat.cols(), D); - - let mut coeffs = [F::ZERO; D]; - for i in 0..D { - coeffs[i] = mat[(i, 0)]; - } - cf_inv(coeffs) -} - -fn default_mixers() -> CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt> { - fn mix_rhos_commits(rhos: &[Mat], cs: &[Cmt]) -> Cmt { - assert!(!cs.is_empty(), "mix_rhos_commits: empty commitments"); - if cs.len() == 1 { - return 
cs[0].clone(); - } - let rq_els: Vec = rhos.iter().map(rot_matrix_to_rq).collect(); - s_lincomb(&rq_els, cs).expect("s_lincomb should succeed") - } - fn combine_b_pows(cs: &[Cmt], b: u32) -> Cmt { - assert!(!cs.is_empty(), "combine_b_pows: empty commitments"); - let mut acc = cs[0].clone(); - let mut pow = F::from_u64(b as u64); - for i in 1..cs.len() { - let rq_pow = RqEl::from_field_scalar(pow); - let term = s_mul(&rq_pow, &cs[i]); - acc.add_inplace(&term); - pow *= F::from_u64(b as u64); - } - acc - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } -} +use crate::suite::{default_mixers, setup_ajtai_committer}; fn metadata_only_mem_instance( + mem_id: u32, layout: &PlainMemLayout, init: MemInit, steps: usize, @@ -105,6 +57,7 @@ fn metadata_only_mem_instance( let ell = layout.n_side.trailing_zeros() as usize; ( MemInstance { + mem_id, comms: Vec::new(), k: layout.k, d: layout.d, @@ -122,6 +75,7 @@ fn metadata_only_lut_instance(table: &LutTable, steps: usize) -> (LutInstance let ell = table.n_side.trailing_zeros() as usize; ( LutInstance { + table_id: table.table_id, comms: Vec::new(), k: table.k, d: table.d, @@ -131,6 +85,8 @@ fn metadata_only_lut_instance(table: &LutTable, steps: usize) -> (LutInstance ell, table_spec: None, table: table.content.clone(), + addr_group: None, + selector_group: None, }, LutWitness { mats: Vec::new() }, ) @@ -339,7 +295,7 @@ fn ccs_must_reference_bus_columns_guardrail() { }, ); - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, MemInit::Zero, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, MemInit::Zero, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -454,7 +410,7 @@ fn address_bit_tampering_attack_should_be_rejected() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, 
mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -583,7 +539,7 @@ fn has_read_flag_mismatch_attack_should_be_rejected() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -713,7 +669,7 @@ fn increment_value_tampering_attack_should_be_rejected() { inc_at_write_addr: vec![F::from_u64(100)], // WRONG }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -868,7 +824,7 @@ fn lookup_value_tampering_attack_should_be_rejected() { write_val: vec![F::ZERO], inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -997,7 +953,7 @@ fn bus_region_mismatch_with_twist_trace_should_be_rejected() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -1124,7 +1080,7 @@ fn write_then_read_consistency_attack_should_be_rejected() { }, ); - let (mem_inst1, mem_wit1) = metadata_only_mem_instance(&mem_layout, mem_init_step1, mem_trace_step1.steps); + let (mem_inst1, mem_wit1) = metadata_only_mem_instance(0, &mem_layout, mem_init_step1, mem_trace_step1.steps); // Step 2: ATTACK - Read from addr 0, claim value is 0 (should be 100) let mem_init_step2 = MemInit::Sparse(vec![(0, 
F::from_u64(100))]); // State after step 1 @@ -1163,7 +1119,7 @@ fn write_then_read_consistency_attack_should_be_rejected() { }, ); - let (mem_inst2, mem_wit2) = metadata_only_mem_instance(&mem_layout, mem_init_step2, mem_trace_step2.steps); + let (mem_inst2, mem_wit2) = metadata_only_mem_instance(0, &mem_layout, mem_init_step2, mem_trace_step2.steps); let steps_witness = vec![ StepWitnessBundle { @@ -1300,7 +1256,7 @@ fn correct_witness_should_verify() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, diff --git a/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_control_attacks.rs b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_control_attacks.rs new file mode 100644 index 00000000..ef5de335 --- /dev/null +++ b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_control_attacks.rs @@ -0,0 +1,164 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, BranchCondition, RiscvInstruction, RiscvOpcode}; +use neo_memory::riscv::trace::{rv32_decode_lookup_backed_cols, Rv32DecodeSidecarLayout, Rv32TraceLayout}; +use p3_field::PrimeCharacteristicRing; + +fn prove_control_trace_program(program: Vec) -> (Rv32TraceWiringRun, ShardProof) { + let program_bytes = encode_program(&program); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); + let proof = run.proof().clone(); + (run, proof) +} + +fn rv32_wp_opening_cols(layout: &Rv32TraceLayout) -> Vec { + vec![ + layout.active, + layout.instr_word, + layout.rs1_addr, + layout.rs1_val, + layout.rs2_addr, + layout.rs2_val, + layout.rd_addr, + 
layout.rd_val, + layout.ram_addr, + layout.ram_rv, + layout.ram_wv, + layout.shout_has_lookup, + layout.shout_val, + layout.shout_lhs, + layout.shout_rhs, + layout.jalr_drop_bit, + layout.pc_before, + layout.pc_after, + ] +} + +fn tamper_control_decode_opening_scalar(proof: &mut ShardProof, decode_col: usize) { + let layout = Rv32DecodeSidecarLayout::new(); + let decode_open_cols = rv32_decode_lookup_backed_cols(&layout); + assert_eq!( + proof.steps[0].mem.wp_me_claims.len(), + 1, + "expected one WP ME claim carrying decode openings for control stage checks" + ); + let me = &mut proof.steps[0].mem.wp_me_claims[0]; + let decode_start = me + .y_scalars + .len() + .checked_sub(decode_open_cols.len()) + .expect("control stage decode opening shape in WP ME tail"); + let open_idx = decode_open_cols + .iter() + .position(|&c| c == decode_col) + .expect("decode col must be present in control stage decode opening set"); + me.y_scalars[decode_start + open_idx] += K::ONE; +} + +fn tamper_control_wp_opening_scalar(proof: &mut ShardProof, trace_col: usize) { + let layout = Rv32TraceLayout::new(); + let open_cols = rv32_wp_opening_cols(&layout); + let decode_open_cols = rv32_decode_lookup_backed_cols(&Rv32DecodeSidecarLayout::new()); + let open_idx = open_cols + .iter() + .position(|&c| c == trace_col) + .expect("trace col must be present in control stage WP opening set"); + assert_eq!( + proof.steps[0].mem.wp_me_claims.len(), + 1, + "expected one WP ME claim reused by control stage checks" + ); + let me = &mut proof.steps[0].mem.wp_me_claims[0]; + let core_t = me + .y_scalars + .len() + .checked_sub(decode_open_cols.len()) + .expect("control stage decode opening tail shape") + .checked_sub(open_cols.len()) + .expect("control stage WP opening shape"); + me.y_scalars[core_t + open_idx] += K::ONE; +} + +#[test] +fn control_jal_target_tamper_is_rejected() { + let program = vec![ + RiscvInstruction::Jal { rd: 1, imm: 8 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + 
rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let (run, mut proof) = prove_control_trace_program(program); + let decode = Rv32DecodeSidecarLayout::new(); + tamper_control_decode_opening_scalar(&mut proof, decode.imm_j); + assert!( + run.verify_proof(&proof).is_err(), + "tampered control stage JAL target opening must fail verification" + ); +} + +#[test] +fn control_jalr_target_tamper_is_rejected() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 8, + }, + RiscvInstruction::Jalr { rd: 2, rs1: 1, imm: 0 }, + RiscvInstruction::Halt, + ]; + let (run, mut proof) = prove_control_trace_program(program); + let decode = Rv32DecodeSidecarLayout::new(); + tamper_control_decode_opening_scalar(&mut proof, decode.imm_i); + assert!( + run.verify_proof(&proof).is_err(), + "tampered control stage JALR target opening must fail verification" + ); +} + +#[test] +fn control_branch_decision_target_tamper_is_rejected() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Branch { + cond: BranchCondition::Ne, + rs1: 0, + rs2: 0, + imm: 8, + }, + RiscvInstruction::Halt, + ]; + let (run, mut proof) = prove_control_trace_program(program); + let decode = Rv32DecodeSidecarLayout::new(); + tamper_control_decode_opening_scalar(&mut proof, decode.funct3_bit[0]); + assert!( + run.verify_proof(&proof).is_err(), + "tampered control stage branch decision/target opening must fail verification" + ); +} + +#[test] +fn control_control_writeback_tamper_is_rejected() { + let program = vec![RiscvInstruction::Auipc { rd: 1, imm: 1 }, RiscvInstruction::Halt]; + let (run, mut proof) = prove_control_trace_program(program); + let trace = Rv32TraceLayout::new(); + tamper_control_wp_opening_scalar(&mut proof, trace.rd_val); + assert!( + run.verify_proof(&proof).is_err(), + "tampered control stage control-writeback opening must fail verification" + ); +} diff --git 
a/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_decode_attacks.rs b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_decode_attacks.rs new file mode 100644 index 00000000..7d3c67e0 --- /dev/null +++ b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_decode_attacks.rs @@ -0,0 +1,83 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use neo_memory::riscv::trace::{rv32_decode_lookup_backed_cols, Rv32DecodeSidecarLayout}; +use p3_field::PrimeCharacteristicRing; + +fn prove_decode_trace_program() -> (Rv32TraceWiringRun, ShardProof) { + // Program exercises both ALU-imm and ALU-reg decode/linkage paths. + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Add, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); + + let proof = run.proof().clone(); + (run, proof) +} + +fn tamper_decode_opening_scalar(proof: &mut ShardProof, decode_col: usize) { + let layout = Rv32DecodeSidecarLayout::new(); + let decode_open_cols = rv32_decode_lookup_backed_cols(&layout); + assert_eq!( + proof.steps[0].mem.wp_me_claims.len(), + 1, + "expected one WP ME claim carrying decode openings" + ); + let me = &mut proof.steps[0].mem.wp_me_claims[0]; + let decode_start = me + .y_scalars + .len() + .checked_sub(decode_open_cols.len()) + .expect("decode openings must be appended to WP ME tail"); + let open_idx = decode_open_cols + .iter() + .position(|&c| c == decode_col) + .expect("decode col must be present in WP 
decode opening tail"); + me.y_scalars[decode_start + open_idx] += K::ONE; +} + +#[test] +fn decode_write_gate_tamper_is_rejected() { + let (run, mut proof) = prove_decode_trace_program(); + let layout = Rv32DecodeSidecarLayout::new(); + tamper_decode_opening_scalar(&mut proof, layout.op_alu_imm); + assert!( + run.verify_proof(&proof).is_err(), + "tampered decode stage opcode-class opening must fail verification" + ); +} + +#[test] +fn decode_alu_table_delta_tamper_is_rejected() { + let (run, mut proof) = prove_decode_trace_program(); + let layout = Rv32DecodeSidecarLayout::new(); + tamper_decode_opening_scalar(&mut proof, layout.rs2); + assert!( + run.verify_proof(&proof).is_err(), + "tampered decode stage rs2-decode opening must fail verification" + ); +} diff --git a/crates/neo-fold/tests/shared_cpu_bus_layout_consistency.rs b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_layout_consistency.rs similarity index 88% rename from crates/neo-fold/tests/shared_cpu_bus_layout_consistency.rs rename to crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_layout_consistency.rs index 54f7dae2..972c0d21 100644 --- a/crates/neo-fold/tests/shared_cpu_bus_layout_consistency.rs +++ b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_layout_consistency.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] -#[path = "common/fixtures.rs"] +#[path = "../../common/fixtures.rs"] mod fixtures; use fixtures::{build_twist_shout_2step_fixture, prove}; @@ -18,7 +18,7 @@ fn shared_cpu_bus_copyout_indices_match_bus_layout() { let step0_inst = &fx.steps_instance[0]; let ccs_out0 = &proof.steps[0].fold.ccs_out[0]; - let s0 = fx.ccs.ensure_identity_first().expect("identity-first"); + let s0 = fx.ccs.clone(); let base_t = s0.t(); let bus = build_bus_layout_for_instances( @@ -47,7 +47,13 @@ fn shared_cpu_bus_copyout_indices_match_bus_layout() { let shout0 = &bus.shout_cols[0].lanes[0]; let twist0 = &bus.twist_cols[0].lanes[0]; - let col_ids = [shout0.has_lookup, shout0.val, 
twist0.has_write, twist0.wv, twist0.inc]; + let col_ids = [ + shout0.has_lookup, + shout0.primary_val(), + twist0.has_write, + twist0.wv, + twist0.inc, + ]; for col_id in col_ids { let z_idx = bus.bus_cell(col_id, 0); diff --git a/crates/neo-fold/tests/shared_cpu_bus_linkage.rs b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_linkage.rs similarity index 84% rename from crates/neo-fold/tests/shared_cpu_bus_linkage.rs rename to crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_linkage.rs index 7fdc9a4f..d3e8fa34 100644 --- a/crates/neo-fold/tests/shared_cpu_bus_linkage.rs +++ b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_linkage.rs @@ -1,9 +1,8 @@ #![allow(non_snake_case)] use std::marker::PhantomData; -use std::sync::Arc; -use neo_ajtai::{s_lincomb, s_mul, setup as ajtai_setup, AjtaiSModule, Commitment as Cmt}; +use neo_ajtai::{AjtaiSModule, Commitment as Cmt}; use neo_ccs::poly::SparsePoly; use neo_ccs::relations::{CcsStructure, McsInstance, McsWitness}; use neo_ccs::traits::SModuleHomomorphism; @@ -11,62 +10,15 @@ use neo_ccs::Mat; use neo_fold::pi_ccs::FoldingMode; use neo_fold::shard::{fold_shard_prove, fold_shard_verify, CommitMixers}; use neo_fold::PiCcsError; -use neo_math::ring::Rq as RqEl; -use neo_math::{D, F, K}; +use neo_math::{F, K}; use neo_memory::plain::{LutTable, PlainLutTrace, PlainMemLayout, PlainMemTrace}; use neo_memory::witness::{StepInstanceBundle, StepWitnessBundle}; use neo_memory::MemInit; use neo_params::NeoParams; use neo_transcript::{Poseidon2Transcript, Transcript}; use p3_field::PrimeCharacteristicRing; -use rand_chacha::rand_core::SeedableRng; -use rand_chacha::ChaCha8Rng; -fn setup_ajtai_committer(params: &NeoParams, m: usize) -> AjtaiSModule { - let mut rng = ChaCha8Rng::seed_from_u64(7); - let pp = ajtai_setup(&mut rng, D, params.kappa as usize, m).expect("Ajtai setup should succeed"); - AjtaiSModule::new(Arc::new(pp)) -} - -fn rot_matrix_to_rq(mat: &Mat) -> RqEl { - use neo_math::ring::cf_inv; - - 
debug_assert_eq!(mat.rows(), D); - debug_assert_eq!(mat.cols(), D); - - let mut coeffs = [F::ZERO; D]; - for i in 0..D { - coeffs[i] = mat[(i, 0)]; - } - cf_inv(coeffs) -} - -fn default_mixers() -> CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt> { - fn mix_rhos_commits(rhos: &[Mat], cs: &[Cmt]) -> Cmt { - assert!(!cs.is_empty(), "mix_rhos_commits: empty commitments"); - if cs.len() == 1 { - return cs[0].clone(); - } - let rq_els: Vec = rhos.iter().map(rot_matrix_to_rq).collect(); - s_lincomb(&rq_els, cs).expect("s_lincomb should succeed") - } - fn combine_b_pows(cs: &[Cmt], b: u32) -> Cmt { - assert!(!cs.is_empty(), "combine_b_pows: empty commitments"); - let mut acc = cs[0].clone(); - let mut pow = F::from_u64(b as u64); - for i in 1..cs.len() { - let rq_pow = RqEl::from_field_scalar(pow); - let term = s_mul(&rq_pow, &cs[i]); - acc.add_inplace(&term); - pow *= F::from_u64(b as u64); - } - acc - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } -} +use crate::suite::{default_mixers, setup_ajtai_committer}; fn create_identity_ccs(n: usize) -> CcsStructure { let mat = Mat::identity(n); @@ -241,6 +193,7 @@ fn build_one_step_fixture(seed: u64) -> SharedBusFixture { let lut_ell = lut_table.n_side.trailing_zeros() as usize; let mem_inst = neo_memory::witness::MemInstance:: { + mem_id: 0, comms: Vec::new(), k: mem_layout.k, d: mem_layout.d, @@ -253,6 +206,7 @@ fn build_one_step_fixture(seed: u64) -> SharedBusFixture { let mem_wit = neo_memory::witness::MemWitness { mats: Vec::new() }; let lut_inst = neo_memory::witness::LutInstance:: { + table_id: lut_table.table_id, comms: Vec::new(), k: lut_table.k, d: lut_table.d, @@ -262,6 +216,8 @@ fn build_one_step_fixture(seed: u64) -> SharedBusFixture { ell: lut_ell, table_spec: None, table: lut_table.content.clone(), + addr_group: None, + selector_group: None, }; let lut_wit = neo_memory::witness::LutWitness { mats: Vec::new() }; @@ -298,7 +254,7 @@ fn build_one_step_fixture(seed: u64) -> SharedBusFixture { fn 
prove_and_verify_shared(fx: &SharedBusFixture) -> Result<(), PiCcsError> { let mut tr = Poseidon2Transcript::new(b"shared-cpu-bus"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr, &fx.params, &fx.ccs, @@ -311,7 +267,7 @@ fn prove_and_verify_shared(fx: &SharedBusFixture) -> Result<(), PiCcsError> { let mut tr_v = Poseidon2Transcript::new(b"shared-cpu-bus"); let _outputs = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_v, &fx.params, &fx.ccs, @@ -335,7 +291,7 @@ fn shared_cpu_bus_tamper_bus_opening_fails() { let mut tr = Poseidon2Transcript::new(b"shared-cpu-bus"); let mut proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr, &fx.params, &fx.ccs, @@ -362,7 +318,7 @@ fn shared_cpu_bus_tamper_bus_opening_fails() { let mut tr_v = Poseidon2Transcript::new(b"shared-cpu-bus"); assert!( fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_v, &fx.params, &fx.ccs, @@ -382,7 +338,7 @@ fn shared_cpu_bus_missing_cpu_me_claim_val_fails() { let mut tr = Poseidon2Transcript::new(b"shared-cpu-bus"); let mut proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr, &fx.params, &fx.ccs, @@ -395,12 +351,12 @@ fn shared_cpu_bus_missing_cpu_me_claim_val_fails() { .expect("prove"); // Shared-bus mode expects CPU ME claims at r_val inside mem proof, so dropping them must fail. 
- proof.steps[0].mem.cpu_me_claims_val.clear(); + proof.steps[0].mem.val_me_claims.clear(); let mut tr_v = Poseidon2Transcript::new(b"shared-cpu-bus"); assert!( fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_v, &fx.params, &fx.ccs, diff --git a/crates/neo-fold/tests/shared_cpu_bus_padding_attacks.rs b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_padding_attacks.rs similarity index 93% rename from crates/neo-fold/tests/shared_cpu_bus_padding_attacks.rs rename to crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_padding_attacks.rs index d579f2cb..39d9bfb8 100644 --- a/crates/neo-fold/tests/shared_cpu_bus_padding_attacks.rs +++ b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_padding_attacks.rs @@ -21,19 +21,15 @@ #![allow(non_snake_case)] use std::marker::PhantomData; -use std::sync::Arc; -use neo_ajtai::{s_lincomb, s_mul, setup as ajtai_setup, AjtaiSModule, Commitment as Cmt}; +use neo_ajtai::Commitment as Cmt; use neo_ccs::relations::{CcsStructure, McsInstance, McsWitness}; use neo_ccs::traits::SModuleHomomorphism; -use neo_ccs::Mat; use neo_fold::pi_ccs::FoldingMode; use neo_fold::shard::{ fold_shard_prove as fold_shard_prove_shared_cpu_bus, fold_shard_verify as fold_shard_verify_shared_cpu_bus, - CommitMixers, }; -use neo_math::ring::Rq as RqEl; -use neo_math::{D, F, K}; +use neo_math::{F, K}; use neo_memory::cpu::build_bus_layout_for_instances; use neo_memory::cpu::constraints::{CpuColumnLayout, CpuConstraintBuilder}; use neo_memory::plain::{LutTable, PlainLutTrace, PlainMemLayout, PlainMemTrace}; @@ -42,60 +38,15 @@ use neo_memory::MemInit; use neo_params::NeoParams; use neo_transcript::{Poseidon2Transcript, Transcript}; use p3_field::PrimeCharacteristicRing; -use rand_chacha::rand_core::SeedableRng; -use rand_chacha::ChaCha8Rng; // ============================================================================ // Test Infrastructure (same as comprehensive attacks) // 
============================================================================ -fn setup_ajtai_committer(params: &NeoParams, m: usize) -> AjtaiSModule { - let mut rng = ChaCha8Rng::seed_from_u64(7); - let pp = ajtai_setup(&mut rng, D, params.kappa as usize, m).expect("Ajtai setup should succeed"); - AjtaiSModule::new(Arc::new(pp)) -} - -fn rot_matrix_to_rq(mat: &Mat) -> RqEl { - use neo_math::ring::cf_inv; - - debug_assert_eq!(mat.rows(), D); - debug_assert_eq!(mat.cols(), D); - - let mut coeffs = [F::ZERO; D]; - for i in 0..D { - coeffs[i] = mat[(i, 0)]; - } - cf_inv(coeffs) -} - -fn default_mixers() -> CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt> { - fn mix_rhos_commits(rhos: &[Mat], cs: &[Cmt]) -> Cmt { - assert!(!cs.is_empty(), "mix_rhos_commits: empty commitments"); - if cs.len() == 1 { - return cs[0].clone(); - } - let rq_els: Vec = rhos.iter().map(rot_matrix_to_rq).collect(); - s_lincomb(&rq_els, cs).expect("s_lincomb should succeed") - } - fn combine_b_pows(cs: &[Cmt], b: u32) -> Cmt { - assert!(!cs.is_empty(), "combine_b_pows: empty commitments"); - let mut acc = cs[0].clone(); - let mut pow = F::from_u64(b as u64); - for i in 1..cs.len() { - let rq_pow = RqEl::from_field_scalar(pow); - let term = s_mul(&rq_pow, &cs[i]); - acc.add_inplace(&term); - pow *= F::from_u64(b as u64); - } - acc - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } -} +use crate::suite::{default_mixers, setup_ajtai_committer}; fn metadata_only_mem_instance( + mem_id: u32, layout: &PlainMemLayout, init: MemInit, steps: usize, @@ -103,6 +54,7 @@ fn metadata_only_mem_instance( let ell = layout.n_side.trailing_zeros() as usize; ( MemInstance { + mem_id, comms: Vec::new(), k: layout.k, d: layout.d, @@ -120,6 +72,7 @@ fn metadata_only_lut_instance(table: &LutTable, steps: usize) -> (LutInstance let ell = table.n_side.trailing_zeros() as usize; ( LutInstance { + table_id: table.table_id, comms: Vec::new(), k: table.k, d: table.d, @@ -129,6 +82,8 @@ fn 
metadata_only_lut_instance(table: &LutTable, steps: usize) -> (LutInstance ell, table_spec: None, table: table.content.clone(), + addr_group: None, + selector_group: None, }, LutWitness { mats: Vec::new() }, ) @@ -267,7 +222,7 @@ fn has_write_flag_mismatch_wv_nonzero_should_be_rejected() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -388,7 +343,7 @@ fn has_write_flag_mismatch_inc_nonzero_should_be_rejected() { inc_at_write_addr: vec![F::from_u64(50)], // Attack }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -509,7 +464,7 @@ fn has_read_flag_mismatch_ra_bits_nonzero_should_be_rejected() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -630,7 +585,7 @@ fn has_write_flag_mismatch_wa_bits_nonzero_should_be_rejected() { inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -775,7 +730,7 @@ fn has_lookup_flag_mismatch_val_nonzero_should_be_rejected() { write_val: vec![F::ZERO], inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, 
&mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, @@ -916,7 +871,7 @@ fn has_lookup_flag_mismatch_addr_bits_nonzero_should_be_rejected() { write_val: vec![F::ZERO], inc_at_write_addr: vec![F::ZERO], }; - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, mem_init, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, mem_init, mem_trace.steps); let steps_witness = vec![StepWitnessBundle { mcs, diff --git a/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_width_attacks.rs b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_width_attacks.rs new file mode 100644 index 00000000..7b0e8db2 --- /dev/null +++ b/crates/neo-fold/tests/suites/shared_bus/shared_cpu_bus_width_attacks.rs @@ -0,0 +1,98 @@ +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::shard::ShardProof; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use neo_memory::riscv::trace::{rv32_width_lookup_backed_cols, Rv32WidthSidecarLayout}; +use p3_field::PrimeCharacteristicRing; + +fn prove_width_trace_program() -> (Rv32TraceWiringRun, ShardProof) { + // Program exercises load/store selector and width semantics: + // ADDI x1, x0, 1 + // SW x1, 0(x0) + // LW x2, 0(x0) + // HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0, + }, + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .prove() + .expect("trace wiring prove"); + run.verify().expect("trace wiring verify"); + + let proof = run.proof().clone(); + (run, proof) +} + +fn tamper_width_opening_scalar(proof: &mut ShardProof, width_col: 
usize) { + let layout = Rv32WidthSidecarLayout::new(); + let width_open_cols = rv32_width_lookup_backed_cols(&layout); + assert_eq!( + proof.steps[0].mem.wp_me_claims.len(), + 1, + "expected one WP ME claim carrying width lookup openings" + ); + let me = &mut proof.steps[0].mem.wp_me_claims[0]; + let width_open_start = me + .y_scalars + .len() + .checked_sub(width_open_cols.len()) + .expect("width openings must be appended to WP ME tail"); + let width_idx = width_open_cols + .iter() + .position(|&c| c == width_col) + .expect("expected width lookup opening column"); + me.y_scalars[width_open_start + width_idx] += K::ONE; +} + +#[test] +fn width_low_bit_tamper_is_rejected() { + let (run, mut proof) = prove_width_trace_program(); + let layout = Rv32WidthSidecarLayout::new(); + tamper_width_opening_scalar(&mut proof, layout.ram_rv_low_bit[0]); + assert!( + run.verify_proof(&proof).is_err(), + "tampered width stage low-bit opening must fail verification" + ); +} + +#[test] +fn width_load_semantics_tamper_is_rejected() { + let (run, mut proof) = prove_width_trace_program(); + let layout = Rv32WidthSidecarLayout::new(); + tamper_width_opening_scalar(&mut proof, layout.ram_rv_q16); + assert!( + run.verify_proof(&proof).is_err(), + "tampered width stage load-semantics opening must fail verification" + ); +} + +#[test] +fn width_store_semantics_tamper_is_rejected() { + let (run, mut proof) = prove_width_trace_program(); + let layout = Rv32WidthSidecarLayout::new(); + tamper_width_opening_scalar(&mut proof, layout.rs2_low_bit[0]); + assert!( + run.verify_proof(&proof).is_err(), + "tampered width stage store-semantics opening must fail verification" + ); +} diff --git a/crates/neo-fold/tests/ts_route_a_negative.rs b/crates/neo-fold/tests/suites/shared_bus/ts_route_a_negative.rs similarity index 100% rename from crates/neo-fold/tests/ts_route_a_negative.rs rename to crates/neo-fold/tests/suites/shared_bus/ts_route_a_negative.rs diff --git 
a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/mod.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/mod.rs new file mode 100644 index 00000000..cd1c3c8b --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/mod.rs @@ -0,0 +1,16 @@ +mod riscv_trace_shout_bitwise_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_div_rem_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_divu_remu_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_eq_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_event_table_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_mul_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_mulhu_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_sll_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_slt_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_sltu_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_sra_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_srl_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_sub_no_shared_cpu_bus_e2e; +mod riscv_trace_shout_xor_no_shared_cpu_bus_e2e; diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_bitwise_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_bitwise_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..1764a57f --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_bitwise_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,265 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use 
neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, interleave_bits, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, + RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use neo_vm_trace::ShoutEvent; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z_packed_bitwise( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 34 { + return Err(format!( + "build_shout_only_bus_z_packed_bitwise: expected ell_addr=34 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z_packed_bitwise: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != t { + return Err("build_shout_only_bus_z_packed_bitwise: lane length mismatch".into()); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z_packed_bitwise: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if 
has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + // Packed-key layout (ell_addr=34): + // [lhs_u32, rhs_u32, lhs_digits[0..16], rhs_digits[0..16]] where each digit is base-4 in {0,1,2,3}. + let mut packed = [F::ZERO; 34]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let lhs_u32 = lhs_u64 as u32; + let rhs_u32 = rhs_u64 as u32; + + packed[0] = F::from_u64(lhs_u32 as u64); + packed[1] = F::from_u64(rhs_u32 as u64); + + for i in 0..16usize { + let a = (lhs_u32 >> (2 * i)) & 3; + let b = (rhs_u32 >> (2 * i)) & 3; + packed[2 + i] = F::from_u64(a as u64); + packed[18 + i] = F::from_u64(b as u64); + } + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_bitwise_packed_prove_verify() { + // Program: + // - LUI x1, 0x80000 (x1 = 0x80000000) + // - XORI x2, x0, 1 (x2 = 1) + // - OR x3, x1, x2 (x3 = 0x80000001) + // - ANDI x4, x3, 3 (x4 = 1) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 4, + rs1: 3, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let mut exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + 
exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + // Inject a single ANDN event into an otherwise shout-free row so we can exercise packed ANDN. + { + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let shout_id = shout.opcode_to_id(RiscvOpcode::Andn); + let lhs: u32 = 0x8000_0001; + let rhs: u32 = 0x0000_0003; + let val: u32 = lhs & !rhs; + exec.rows[0].shout_events.clear(); + exec.rows[0].shout_events.push(ShoutEvent { + shout_id, + key: interleave_bits(lhs as u64, rhs as u64) as u64, + value: val as u64, + }); + } + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 34usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instances: AND/ANDN/OR/XOR packed, 1 lane each. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![ + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::And).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Or).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Xor).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Andn).0, + ]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + assert_eq!(shout_lanes.len(), 4); + + let mut lut_instances = Vec::new(); + for (idx, opcode) in [RiscvOpcode::And, RiscvOpcode::Or, RiscvOpcode::Xor, RiscvOpcode::Andn] + .into_iter() + .enumerate() + { + let inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 34, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { opcode, xlen: 32 }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let z = build_shout_only_bus_z_packed_bitwise(ccs.m, layout.m_in, t, inst.d * inst.ell, &shout_lanes[idx], &x) + .expect("packed bitwise z"); + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + let c = l.commit(&Z); + let inst = LutInstance:: { comms: vec![c], ..inst }; + let wit = LutWitness { mats: vec![Z] }; + lut_instances.push((inst, wit)); + } + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances, + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-bitwise-packed"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-bitwise-packed"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + 
.expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_div_rem_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_div_rem_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..e0df9832 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_div_rem_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,93 @@ +#![allow(non_snake_case)] + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_div_rem_packed_prove_verify() { + // Program: + // - x1 = -7*4096, x2 = 3*4096 (DIV=-2, REM=-4096) + // - x1 = -1*4096, x2 = 3*4096 (DIV=0, REM=-4096) + // - x1 = INT_MIN, x2 = -1 (DIV overflow case; REM=0) + // - x1 = INT_MIN, x2 = 0 (DIV by zero => -1; REM by zero => lhs) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: -7 }, + RiscvInstruction::Lui { rd: 2, imm: 3 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 4, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Lui { rd: 1, imm: -1 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 5, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 6, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Lui { rd: 1, imm: -524_288 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: -1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 7, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 8, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 9, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: 
RiscvOpcode::Rem, + rd: 10, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(1) + .max_steps(program.len()) + .reg_output_claim(/*reg=*/ 9, F::from_u64(0xffff_ffff)) + .reg_output_claim(/*reg=*/ 10, F::from_u64(0x8000_0000)) + .prove() + .expect("rv32 trace prove (WB/WP route, DIV/REM)"); + run.verify().expect("rv32 trace verify (WB/WP route, DIV/REM)"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_divu_remu_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_divu_remu_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..baf41219 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_divu_remu_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,77 @@ +#![allow(non_snake_case)] + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_divu_remu_packed_prove_verify() { + // Program: + // - x1 = 91 + // - x2 = 7 + // - DIVU x3, x1, x2 (13) + // - REMU x4, x1, x2 (0) + // - x2 = 0 + // - DIVU x5, x1, x2 (0xffffffff) + // - REMU x6, x1, x2 (91) + // - HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 91, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 7, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + rd: 4, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 5, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: 
RiscvOpcode::Remu, + rd: 6, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(1) + .max_steps(program.len()) + .reg_output_claim(/*reg=*/ 3, F::from_u64(13)) + .reg_output_claim(/*reg=*/ 4, F::from_u64(0)) + .reg_output_claim(/*reg=*/ 5, F::from_u64(0xffff_ffff)) + .reg_output_claim(/*reg=*/ 6, F::from_u64(91)) + .prove() + .expect("rv32 trace prove (WB/WP route, DIVU/REMU)"); + run.verify() + .expect("rv32 trace verify (WB/WP route, DIVU/REMU)"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_eq_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_eq_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..fb15121e --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_eq_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,259 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemory, + RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use 
neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 35 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=35 for packed EQ (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_shout_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + // Packed-key layout (ell_addr=35): + // [lhs_u32, rhs_u32, borrow_bit, diff_bits[0..32]]. 
+ let mut packed = [F::ZERO; 35]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let borrow = if lhs < rhs { 1u32 } else { 0u32 }; + let diff = lhs.wrapping_sub(rhs); + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = if borrow == 1 { F::ONE } else { F::ZERO }; + for bit in 0..32usize { + packed[3 + bit] = if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_eq_prove_verify() { + // Program (use BEQ to generate `Eq` Shout events; there is no dedicated `Eq` ALU instruction encoding): + // - LUI x1, 0 (x1 = 0) + // - LUI x2, 1 (x2 = 4096) + // - BEQ x1, x2, +8 (not taken; EQ=0) + // - LUI x2, 0 (x2 = 0) + // - BEQ x1, x2, +8 (taken; EQ=1) + // - NOP (skipped) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0 }, + RiscvInstruction::Lui { rd: 2, imm: 1 }, + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 1, + rs2: 2, + imm: 8, + }, + RiscvInstruction::Lui { rd: 2, imm: 0 }, + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 1, + rs2: 2, + imm: 8, + }, + RiscvInstruction::Nop, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + 
exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + // Legacy no-shared packed Shout tests need enough witness width to host the packed bus lane. + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 35usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: EQ table, 1 lane. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Eq).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let eq_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 35, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Eq, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let eq_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ eq_lut_inst.d * eq_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("EQ Shout z"); + let eq_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &eq_z); + let eq_c = l.commit(&eq_Z); + let eq_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![eq_c], + ..eq_lut_inst + }; + let eq_lut_wit = LutWitness { mats: vec![eq_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(eq_lut_inst, eq_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-eq"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-eq"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_event_table_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_event_table_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..58cdf0de --- 
/dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_event_table_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,224 @@ +#![allow(non_snake_case)] + +#[path = "../../../common/riscv_shout_event_table_packed.rs"] +mod event_table_packed; + +use std::collections::BTreeMap; +use std::marker::PhantomData; + +use crate::suite::{default_mixers, setup_ajtai_committer}; +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::{Rv32ExecTable, Rv32ShoutEventRow, Rv32ShoutEventTable}; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_event_table_packed_prove_verify() { + // Compact program that still exercises event-table packed mode over multiple opcode families. 
+ let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 5, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 7, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 1, + rs2: 1, + imm: 4, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 64).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. 
+ let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Event table extraction. + let event_table = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + assert!(!event_table.rows.is_empty(), "expected non-empty Shout event table"); + + // Group by shout_id (stable) and build one event-table packed instance per opcode. + let mut by_id: BTreeMap)> = BTreeMap::new(); + for row in event_table.rows.iter() { + let opcode = row + .opcode + .ok_or_else(|| format!("missing opcode for shout_id={}", row.shout_id)) + .unwrap(); + let entry = by_id + .entry(row.shout_id) + .or_insert_with(|| (opcode, Vec::new())); + if entry.0 != opcode { + panic!( + "opcode mismatch for shout_id={}: {:?} vs {:?}", + row.shout_id, entry.0, opcode + ); + } + entry.1.push(row.clone()); + } + + let ell_n = event_table_packed::ell_n_from_ccs_n(ccs.n); + assert!(ell_n >= 3, "event-table packed requires ell_n>=3 (got ell_n={ell_n})"); + assert!(ell_n <= 64, "event-table packed requires ell_n<=64 (got ell_n={ell_n})"); + + let tables = RiscvShoutTables::new(32); + let expected: BTreeMap = [ + (RiscvOpcode::Add, 2usize), + (RiscvOpcode::Or, 1), + (RiscvOpcode::Eq, 1), + ] + .into_iter() + .map(|(op, count)| (tables.opcode_to_id(op).0, (op, count))) + .collect(); + let got: BTreeMap = by_id + .iter() + .map(|(shout_id, (opcode, rows))| (*shout_id, (*opcode, rows.len()))) + .collect(); + assert_eq!(got, expected, "unexpected event-table opcode coverage"); + + let mut lut_instances: Vec<(LutInstance, LutWitness)> = Vec::new(); + for (_shout_id, (opcode, rows)) in by_id.into_iter() { + let steps = rows.len(); + let z = event_table_packed::build_shout_event_table_bus_z(ccs.m, layout.m_in, steps, ell_n, opcode, &rows, 
&x) + .unwrap_or_else(|e| panic!("event-table z build failed (opcode={opcode:?}): {e}")); + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + let c = l.commit(&Z); + + // `d = time_bits + base_d(opcode)` + let base_d = + event_table_packed::rv32_packed_base_d(opcode).unwrap_or_else(|e| panic!("opcode {opcode:?}: {e}")); + let d = ell_n + base_d; + + let inst = LutInstance:: { + table_id: 0, + comms: vec![c], + k: 0, + d, + n_side: 2, + steps, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodeEventTablePacked { + opcode, + xlen: 32, + time_bits: ell_n, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let wit = LutWitness { mats: vec![Z] }; + lut_instances.push((inst, wit)); + } + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances, + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-event-table-packed"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + assert!( + !proof.steps[0].mem.shout_me_claims_time.is_empty(), + "expected Shout ME(time) claims in no-shared-bus mode" + ); + assert_eq!( + proof.steps[0].mem.shout_me_claims_time.len(), + steps_witness[0].lut_instances.len(), + "expected 1 Shout ME(time) claim per Shout instance" + ); + assert_eq!( + proof.steps[0].shout_time_fold.len(), + steps_witness[0].lut_instances.len(), + "expected 1 shout_time_fold per Shout instance" + ); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-event-table-packed"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git 
a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mul_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mul_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..9ab376f3 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mul_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,43 @@ +#![allow(non_snake_case)] + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_mul_prove_verify() { + // Program: + // - LUI x1, 16 (x1 = 65536) + // - LUI x2, 16 (x2 = 65536) + // - MUL x3, x1, x2 (lo = 0) + // - MUL x4, x2, x1 (lo = 0) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 16 }, + RiscvInstruction::Lui { rd: 2, imm: 16 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 4, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(1) + .max_steps(program.len()) + .reg_output_claim(/*reg=*/ 3, F::from_u64(0)) + .reg_output_claim(/*reg=*/ 4, F::from_u64(0)) + .prove() + .expect("rv32 trace prove (WB/WP route, MUL)"); + run.verify().expect("rv32 trace verify (WB/WP route, MUL)"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..a23ac9aa --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,61 @@ +#![allow(non_snake_case)] + +use 
neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_mulh_mulhsu_packed_prove_verify() { + // Program: + // - x1 = -7 + // - x2 = -3 + // - MULH x3, x1, x2 (0) + // - x5 = 13 + // - MULHSU x6, x1, x5 (0xffffffff) + // - HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: -7, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: -3, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulh, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 5, + rs1: 0, + imm: 13, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhsu, + rd: 6, + rs1: 1, + rs2: 5, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(1) + .max_steps(program.len()) + .reg_output_claim(/*reg=*/ 3, F::from_u64(0)) + .reg_output_claim(/*reg=*/ 6, F::from_u64(0xffff_ffff)) + .prove() + .expect("rv32 trace prove (WB/WP route, MULH/MULHSU)"); + run.verify() + .expect("rv32 trace verify (WB/WP route, MULH/MULHSU)"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mulhu_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mulhu_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..baf7d2e5 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_mulhu_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,43 @@ +#![allow(non_snake_case)] + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn 
riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_mulhu_prove_verify() { + // Program: + // - LUI x1, 16 (x1 = 65536) + // - LUI x2, 16 (x2 = 65536) + // - MULHU x3, x1, x2 (hi = 1) + // - MULHU x4, x2, x1 (hi = 1) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 16 }, + RiscvInstruction::Lui { rd: 2, imm: 16 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhu, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhu, + rd: 4, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .chunk_rows(1) + .max_steps(program.len()) + .reg_output_claim(/*reg=*/ 3, F::from_u64(1)) + .reg_output_claim(/*reg=*/ 4, F::from_u64(1)) + .prove() + .expect("rv32 trace prove (WB/WP route, MULHU)"); + run.verify().expect("rv32 trace verify (WB/WP route, MULHU)"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..1cdea885 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,264 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, 
RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 3 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=3 for packed ADD (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_shout_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + // 
Packed-key layout: [lhs_u32, rhs_u32, carry_bit] + let mut packed = [F::ZERO; 3]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = (lhs_u64 as u32) as u64; + let rhs = (rhs_u64 as u32) as u64; + let carry = (lhs.wrapping_add(rhs) >> 32) & 1; + packed[0] = F::from_u64(lhs); + packed[1] = F::from_u64(rhs); + packed[2] = F::from_u64(carry); + } + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_prove_verify() { + // Program: + // - ADDI x1, x0, 1 + // - ADDI x2, x1, 2 + // - HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. 
+ let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: ADD table, 1 lane. + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let add_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 3, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Add, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let add_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ add_lut_inst.d * add_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("ADD Shout z"); + let add_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &add_z); + let add_c = l.commit(&add_Z); + let add_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![add_c], + ..add_lut_inst + }; + let add_lut_wit = LutWitness { mats: vec![add_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(add_lut_inst, add_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + 
&steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + // Sanity: no-shared-bus mode should emit Shout ME(time) claims and fold them. + assert!( + !proof.steps[0].mem.shout_me_claims_time.is_empty(), + "expected Shout ME(time) claims in no-shared-bus mode" + ); + assert!( + !proof.steps[0].shout_time_fold.is_empty(), + "expected shout_time_fold proofs in no-shared-bus mode" + ); + assert!( + proof.steps[0].mem.twist_me_claims_time.is_empty(), + "expected no Twist ME(time) claims when no mem instances are present" + ); + assert!( + proof.steps[0].twist_time_fold.is_empty(), + "expected no twist_time_fold proofs when no mem instances are present" + ); + assert!( + proof.steps[0].mem.val_me_claims.is_empty(), + "expected no val_me_claims when no mem instances are present" + ); + assert!( + proof.steps[0].val_fold.is_empty(), + "expected no val_fold proofs when no mem instances are present" + ); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sll_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sll_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..af5eeb4c --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sll_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,256 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use 
neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=38 for packed SLL (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| 
format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let shamt = (rhs_u64 as u32) & 0x1F; + let wide = (lhs as u64) << shamt; + let carry = (wide >> 32) as u32; + + // Packed-key layout (ell_addr=38): + // [lhs_u32, shamt_bits[0..5], carry_bits[0..32]]. + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + for bit in 0..5 { + packed[1 + bit] = if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + for bit in 0..32 { + packed[6 + bit] = if ((carry >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_sll_prove_verify() { + // Program: + // - LUI x1, 1 (x1 = 4096) + // - SLLI x2, x1, 3 (x2 = 32768) + // - SLLI x3, x1, 0 (x3 = 4096) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 1 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sll, + rd: 2, + rs1: 1, + imm: 3, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sll, + rd: 3, + rs1: 1, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = 
Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SLL table, 1 lane. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sll).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sll_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sll, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let sll_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sll_lut_inst.d * sll_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SLL Shout z"); + let sll_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sll_z); + let sll_c = l.commit(&sll_Z); + let sll_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sll_c], + ..sll_lut_inst + }; + let sll_lut_wit = LutWitness { mats: vec![sll_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sll_lut_inst, sll_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sll"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sll"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_slt_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_slt_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..a63704d6 --- 
/dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_slt_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,262 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 37 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=37 for packed SLT (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = 
build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + + // Signed compare via biased-unsigned compare (flip sign bit). 
+ let lhs_b = lhs ^ 0x8000_0000; + let rhs_b = rhs ^ 0x8000_0000; + let diff = lhs_b.wrapping_sub(rhs_b); + + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(diff as u64); + packed[3] = if lhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[4] = if rhs_sign == 1 { F::ONE } else { F::ZERO }; + for bit in 0..32 { + let b = (diff >> bit) & 1; + packed[5 + bit] = if b == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_slt_prove_verify() { + // Program: + // - LUI x1, 0x80000 (x1 = 0x8000_0000) + // - SLTI x2, x1, 0 (1) + // - SLTI x3, x0, -1 (0) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Slt, + rd: 2, + rs1: 1, + imm: 0, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Slt, + rd: 3, + rs1: 0, + imm: -1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS 
layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 37usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SLT table, 1 lane. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Slt).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let slt_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 37, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Slt, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let slt_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ slt_lut_inst.d * slt_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SLT Shout z"); + let slt_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &slt_z); + let slt_c = l.commit(&slt_Z); + let slt_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![slt_c], + ..slt_lut_inst + }; + let slt_lut_wit = LutWitness { mats: vec![slt_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(slt_lut_inst, slt_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-slt"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-slt"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sltu_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sltu_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..1b08cae5 --- 
/dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sltu_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,255 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 35 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=35 for packed SLTU (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = 
build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let diff = lhs.wrapping_sub(rhs); + + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(diff as u64); + for bit in 0..32 { + let b = (diff >> bit) & 1; + packed[3 + bit] = if b == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_sltu_prove_verify() { + // Program: + // - LUI x1, 0 (x1 = 0) + // - LUI x2, 1 (x2 = 4096) + // - SLTU x3, x1, x2 (1) + // - SLTU x4, x2, x1 (0) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0 }, + RiscvInstruction::Lui { rd: 2, imm: 1 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sltu, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sltu, + rd: 4, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); 
+ + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 35usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SLTU table, 1 lane. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sltu).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sltu_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 35, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sltu, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let sltu_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sltu_lut_inst.d * sltu_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SLTU Shout z"); + let sltu_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sltu_z); + let sltu_c = l.commit(&sltu_Z); + let sltu_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sltu_c], + ..sltu_lut_inst + }; + let sltu_lut_wit = LutWitness { mats: vec![sltu_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sltu_lut_inst, sltu_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sltu"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sltu"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sra_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sra_no_shared_cpu_bus_e2e.rs new file mode 100644 index 
00000000..e1b4ad45 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sra_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,270 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=38 for packed SRA (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = 
build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let shamt = (rhs_u64 as u32) & 0x1F; + let sign = (lhs >> 31) & 1; + + let lhs_signed: i64 = if sign == 1 { + (lhs as i64) - (1i64 << 32) + } else { + lhs as i64 + }; + let val_u32 = lane.value[j] as u32; + let val_signed: i64 = (val_u32 as i64) - (sign as i64) * (1i64 << 32); + let pow2: i64 = 1i64 << shamt; + let rem_i64 = lhs_signed - val_signed * pow2; + if rem_i64 < 0 { + return Err("build_shout_only_bus_z: negative SRA remainder".into()); + } + let rem = rem_i64 as u64; + + // Packed-key layout (ell_addr=38): + // [lhs_u32, shamt_bits[0..5], sign_bit, rem_bits[0..31]]. 
+ let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + for bit in 0..5 { + packed[1 + bit] = if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + packed[6] = if sign == 1 { F::ONE } else { F::ZERO }; + for bit in 0..31 { + packed[7 + bit] = if ((rem >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_sra_prove_verify() { + // Program: + // - LUI x1, 0x80000 (x1 = 0x8000_0000) + // - SRAI x2, x1, 3 (x2 = 0xF000_0000) + // - SRAI x3, x1, 0 (x3 = 0x8000_0000) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 2, + rs1: 1, + imm: 3, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 3, + rs1: 1, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = 
build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SRA table, 1 lane. + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sra).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sra_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sra, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let sra_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sra_lut_inst.d * sra_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SRA Shout z"); + let sra_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sra_z); + let sra_c = l.commit(&sra_Z); + let sra_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sra_c], + ..sra_lut_inst + }; + let sra_lut_wit = LutWitness { mats: vec![sra_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sra_lut_inst, sra_lut_wit)], + 
mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sra"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sra"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_srl_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_srl_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..a41c054b --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_srl_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,260 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=38 for packed SRL (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let shamt = (rhs_u64 as u32) & 0x1F; + let rem: u32 = if shamt == 0 { + 0 + } else { + let mask = (1u64 << shamt) - 1; + ((lhs as u64) & mask) as u32 + }; + + 
// Packed-key layout (ell_addr=38): + // [lhs_u32, shamt_bits[0..5], rem_bits[0..32]]. + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + for bit in 0..5 { + packed[1 + bit] = if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + for bit in 0..32 { + packed[6 + bit] = if ((rem >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_srl_prove_verify() { + // Program: + // - LUI x1, 1 (x1 = 4096) + // - SRLI x2, x1, 3 (x2 = 512) + // - SRLI x3, x1, 0 (x3 = 4096) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 1 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Srl, + rd: 2, + rs1: 1, + imm: 3, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Srl, + rd: 3, + rs1: 1, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = 
build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SRL table, 1 lane. + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Srl).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let srl_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Srl, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let srl_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ srl_lut_inst.d * srl_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SRL Shout z"); + let srl_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &srl_z); + let srl_c = l.commit(&srl_Z); + let srl_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![srl_c], + ..srl_lut_inst + }; + let srl_lut_wit = LutWitness { mats: vec![srl_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(srl_lut_inst, srl_lut_wit)], + 
mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-srl"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-srl"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sub_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sub_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..50ef76ea --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_sub_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,242 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 3 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=3 for packed SUB (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_shout_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + // Packed-key layout: [lhs_u32, rhs_u32, borrow_bit] + let mut packed = [F::ZERO; 3]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs_u32 = lhs_u64 as u32; + let rhs_u32 = rhs_u64 as u32; + let borrow = (lhs_u32 < rhs_u32) as 
u64; + packed[0] = F::from_u64(lhs_u32 as u64); + packed[1] = F::from_u64(rhs_u32 as u64); + packed[2] = F::from_u64(borrow); + } + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_sub_prove_verify() { + // Program: + // - LUI x1, 0 (x1 = 0) + // - LUI x2, 1 (x2 = 4096) + // - SUB x3, x1, x2 (borrow=1) + // - SUB x4, x2, x1 (borrow=0) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0 }, + RiscvInstruction::Lui { rd: 2, imm: 1 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sub, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sub, + rd: 4, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. 
+ let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SUB table, 1 lane. + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sub).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sub_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 3, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sub, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let sub_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sub_lut_inst.d * sub_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SUB Shout z"); + let sub_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sub_z); + let sub_c = l.commit(&sub_Z); + let sub_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sub_c], + ..sub_lut_inst + }; + let sub_lut_wit = LutWitness { mats: vec![sub_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sub_lut_inst, sub_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sub"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + 
&steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sub"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_xor_no_shared_cpu_bus_e2e.rs b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_xor_no_shared_cpu_bus_e2e.rs new file mode 100644 index 00000000..23088806 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/e2e_ops/riscv_trace_shout_xor_no_shared_cpu_bus_e2e.rs @@ -0,0 +1,294 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_ccs::Mat; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn plan_paged_shout_addr(m: usize, m_in: usize, t: usize, ell_addr: usize, lanes: usize) -> Result, String> { + if t == 0 { + return 
Err("plan_paged_shout_addr: t must be >= 1".into()); + } + if m_in > m { + return Err(format!("plan_paged_shout_addr: m_in={m_in} > m={m}")); + } + if lanes == 0 { + return Err("plan_paged_shout_addr: lanes must be >= 1".into()); + } + if ell_addr == 0 { + return Err("plan_paged_shout_addr: ell_addr must be >= 1".into()); + } + + // Match `neo_fold::memory_sidecar::shout_paging::plan_shout_addr_pages`. + let avail = m - m_in; + let max_bus_cols_total = avail / t; + let per_lane_capacity = max_bus_cols_total / lanes; + if per_lane_capacity < 3 { + return Err(format!( + "plan_paged_shout_addr: insufficient per-lane capacity (need >=3 cols per lane, have {per_lane_capacity})" + )); + } + let max_addr_cols_per_page = per_lane_capacity - 2; + + let mut out = Vec::new(); + let mut remaining = ell_addr; + while remaining > 0 { + let take = remaining.min(max_addr_cols_per_page); + out.push(take); + remaining -= take; + } + Ok(out) +} + +fn build_paged_shout_only_bus_zs( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result>, String> { + if x_prefix.len() != m_in { + return Err(format!( + "build_paged_shout_only_bus_zs: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_paged_shout_only_bus_zs: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let page_ell_addrs = plan_paged_shout_addr(m, m_in, t, ell_addr, lanes)?; + let mut out: Vec> = Vec::with_capacity(page_ell_addrs.len()); + + let mut bit_base: usize = 0; + for (page_idx, &page_ell_addr) in page_ell_addrs.iter().enumerate() { + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((page_ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_paged_shout_only_bus_zs: 
expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_paged_shout_only_bus_zs: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + for (local_idx, col_id) in cols.addr_bits.clone().enumerate() { + let bit_idx = bit_base + .checked_add(local_idx) + .ok_or_else(|| "build_paged_shout_only_bus_zs: bit index overflow".to_string())?; + if bit_idx >= 64 { + return Err(format!( + "build_paged_shout_only_bus_zs: bit_idx={bit_idx} out of range for u64 key (page_idx={page_idx})" + )); + } + let b = if has { (lane.key[j] >> bit_idx) & 1 } else { 0 }; + z[bus.bus_cell(col_id, j)] = if b == 1 { F::ONE } else { F::ZERO }; + } + } + } + + out.push(z); + bit_base = bit_base + .checked_add(page_ell_addr) + .ok_or_else(|| "build_paged_shout_only_bus_zs: bit_base overflow".to_string())?; + } + if bit_base != ell_addr { + return Err(format!( + "build_paged_shout_only_bus_zs: paging mismatch (got bit_base={bit_base}, expected ell_addr={ell_addr})" + )); + } + + Ok(out) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_xor_paged_prove_verify() { + // Program: + // - LUI x1, 0x80000 (x1 = 0x80000000) + // - XORI x2, x0, 1 (x2 = 1) + // - XOR x3, x1, x2 (x3 = 0x80000001) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Xor, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let 
decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: XOR table, 1 lane, bit-addressed (ell_addr=64) → paged across multiple mats. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Xor).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let ell_addr = 64usize; + let lanes = 1usize; + let page_ell_addrs = plan_paged_shout_addr(ccs.m, layout.m_in, t, ell_addr, lanes).expect("paging plan"); + + let paged_zs = build_paged_shout_only_bus_zs(ccs.m, layout.m_in, t, ell_addr, lanes, &shout_lanes, &x) + .expect("XOR Shout paged z"); + assert_eq!(paged_zs.len(), page_ell_addrs.len(), "z/page drift"); + + let mut mats: Vec> = Vec::with_capacity(paged_zs.len()); + let mut comms: Vec = Vec::with_capacity(paged_zs.len()); + for z in paged_zs { + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + comms.push(l.commit(&Z)); + mats.push(Z); + } + + let xor_lut_inst = LutInstance:: { + table_id: shout_table_ids[0], + comms, + k: 0, + d: ell_addr, + n_side: 2, + steps: t, + lanes, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcode { + opcode: RiscvOpcode::Xor, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let xor_lut_wit = LutWitness { mats }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(xor_lut_inst, xor_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-xor-paged"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + assert_eq!( + proof.steps[0].mem.shout_me_claims_time.len(), + page_ell_addrs.len(), + "expected 1 Shout ME(time) claim per paging mat" + ); + assert_eq!( + proof.steps[0].shout_time_fold.len(), + page_ell_addrs.len(), + "expected 1 shout_time_fold proof per paging mat" + ); + + 
let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-xor-paged"); + let _ = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect("verify"); +} diff --git a/crates/neo-fold/tests/implicit_shout_table_spec_tests.rs b/crates/neo-fold/tests/suites/trace_shout/implicit_shout_table_spec_tests.rs similarity index 95% rename from crates/neo-fold/tests/implicit_shout_table_spec_tests.rs rename to crates/neo-fold/tests/suites/trace_shout/implicit_shout_table_spec_tests.rs index 1ae62864..b59e7bc9 100644 --- a/crates/neo-fold/tests/implicit_shout_table_spec_tests.rs +++ b/crates/neo-fold/tests/suites/trace_shout/implicit_shout_table_spec_tests.rs @@ -43,16 +43,7 @@ impl SModuleHomomorphism for DummyCommit { } fn default_mixers() -> Mixers { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn create_identity_ccs(n: usize) -> CcsStructure { @@ -123,6 +114,7 @@ fn absorb_step_memory_binds_table_spec() { let make_step = |opcode: RiscvOpcode| StepInstanceBundle:: { mcs_inst: dummy_mcs.clone(), lut_insts: vec![LutInstance { + table_id: 0, comms: Vec::new(), k: 0, d: 64, @@ -132,6 +124,8 @@ fn absorb_step_memory_binds_table_spec() { ell: 1, table_spec: Some(LutTableSpec::RiscvOpcode { opcode, xlen: 32 }), table: vec![], + addr_group: None, + selector_group: None, }], mem_insts: vec![], _phantom: PhantomData, @@ -165,6 +159,7 @@ fn route_a_shout_implicit_table_spec_verifies() { let out = compute_op(opcode, rs1, rs2, xlen); let inst = LutInstance:: { + table_id: 0, comms: Vec::new(), k: 0, d: 64, @@ -174,6 +169,8 @@ fn route_a_shout_implicit_table_spec_verifies() { ell: 1, table_spec: Some(LutTableSpec::RiscvOpcode { opcode, xlen }), table: vec![], + addr_group: None, + 
selector_group: None, }; let wit = LutWitness { mats: Vec::new() }; @@ -195,7 +192,7 @@ fn route_a_shout_implicit_table_spec_verifies() { let mut tr_prove = Poseidon2Transcript::new(b"implicit-shout-table-spec"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -210,7 +207,7 @@ fn route_a_shout_implicit_table_spec_verifies() { let mut tr_verify = Poseidon2Transcript::new(b"implicit-shout-table-spec"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -229,7 +226,7 @@ fn route_a_shout_implicit_table_spec_verifies() { xlen, }); let err = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify_bad, ¶ms, &ccs, @@ -255,6 +252,7 @@ fn route_a_shout_implicit_identity_u32_table_spec_verifies() { let out = addr; let inst = LutInstance:: { + table_id: 0, comms: Vec::new(), k: 0, d: 32, @@ -264,6 +262,8 @@ fn route_a_shout_implicit_identity_u32_table_spec_verifies() { ell: 1, table_spec: Some(LutTableSpec::IdentityU32), table: vec![], + addr_group: None, + selector_group: None, }; let wit = LutWitness { mats: Vec::new() }; @@ -285,7 +285,7 @@ fn route_a_shout_implicit_identity_u32_table_spec_verifies() { let mut tr_prove = Poseidon2Transcript::new(b"implicit-shout-identity-u32-table-spec"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -300,7 +300,7 @@ fn route_a_shout_implicit_identity_u32_table_spec_verifies() { let mut tr_verify = Poseidon2Transcript::new(b"implicit-shout-identity-u32-table-spec"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -316,7 +316,7 @@ fn route_a_shout_implicit_identity_u32_table_spec_verifies() { let mut steps_public_bad = 
[StepInstanceBundle::from(&step_bundle)]; steps_public_bad[0].lut_insts[0].table_spec = None; let err = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify_bad, ¶ms, &ccs, diff --git a/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/mod.rs b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/mod.rs new file mode 100644 index 00000000..e6f631df --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/mod.rs @@ -0,0 +1,4 @@ +mod riscv_trace_shout_event_table_no_shared_cpu_bus_linkage_redteam; +mod riscv_trace_shout_no_shared_cpu_bus_linkage_redteam; +mod riscv_trace_shout_sub_no_shared_cpu_bus_linkage_redteam; +mod riscv_trace_shout_xor_no_shared_cpu_bus_linkage_redteam; diff --git a/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_event_table_no_shared_cpu_bus_linkage_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_event_table_no_shared_cpu_bus_linkage_redteam.rs new file mode 100644 index 00000000..bed583b0 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_event_table_no_shared_cpu_bus_linkage_redteam.rs @@ -0,0 +1,226 @@ +#![allow(non_snake_case)] + +#[path = "../../../common/riscv_shout_event_table_packed.rs"] +mod event_table_packed; + +use std::collections::BTreeMap; +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::{Rv32ExecTable, Rv32ShoutEventRow, Rv32ShoutEventTable}; +use neo_memory::riscv::lookups::{ + decode_program, 
encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn flip_time_bit0_for_all_events( + z: &mut [F], + m: usize, + m_in: usize, + steps: usize, + ell_addr: usize, +) -> Result<(), String> { + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + steps, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + let cols = &bus.shout_cols[0].lanes[0]; + let time_bit0_col = cols.addr_bits.start; + for j in 0..steps { + let idx = bus.bus_cell(time_bit0_col, j); + z[idx] = if z[idx] == F::ZERO { F::ONE } else { F::ZERO }; + } + Ok(()) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_event_table_linkage_redteam() { + // Minimal program; we tamper with an event-table time bit so the hash linkage fails. 
+ let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 5, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 7, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 64).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Event table extraction. 
+ let event_table = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + assert!(!event_table.rows.is_empty(), "expected non-empty Shout event table"); + + // Group by shout_id (stable) and build one event-table packed instance per opcode. + let mut by_id: BTreeMap)> = BTreeMap::new(); + for row in event_table.rows.iter() { + let opcode = row + .opcode + .ok_or_else(|| format!("missing opcode for shout_id={}", row.shout_id)) + .unwrap(); + let entry = by_id + .entry(row.shout_id) + .or_insert_with(|| (opcode, Vec::new())); + if entry.0 != opcode { + panic!( + "opcode mismatch for shout_id={}: {:?} vs {:?}", + row.shout_id, entry.0, opcode + ); + } + entry.1.push(row.clone()); + } + + let ell_n = event_table_packed::ell_n_from_ccs_n(ccs.n); + assert!(ell_n >= 3, "event-table packed requires ell_n>=3 (got ell_n={ell_n})"); + assert!(ell_n <= 64, "event-table packed requires ell_n<=64 (got ell_n={ell_n})"); + + let mut lut_instances: Vec<(LutInstance, LutWitness)> = Vec::new(); + let mut did_tamper = false; + for (_shout_id, (opcode, rows)) in by_id.into_iter() { + let steps = rows.len(); + let base_d = + event_table_packed::rv32_packed_base_d(opcode).unwrap_or_else(|e| panic!("opcode {opcode:?}: {e}")); + let d = ell_n + base_d; + let ell_addr = d; + + let mut z = + event_table_packed::build_shout_event_table_bus_z(ccs.m, layout.m_in, steps, ell_n, opcode, &rows, &x) + .unwrap_or_else(|e| panic!("event-table z build failed (opcode={opcode:?}): {e}")); + + // Tamper with the time-bit prefix of the first instance only (keeps booleanity). 
+ if !did_tamper { + flip_time_bit0_for_all_events(&mut z, ccs.m, layout.m_in, steps, ell_addr).expect("tamper time bit"); + did_tamper = true; + } + + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + let c = l.commit(&Z); + + let inst = LutInstance:: { + table_id: 0, + comms: vec![c], + k: 0, + d, + n_side: 2, + steps, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodeEventTablePacked { + opcode, + xlen: 32, + time_bits: ell_n, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let wit = LutWitness { mats: vec![Z] }; + lut_instances.push((inst, wit)); + } + assert!(did_tamper, "expected to tamper at least one Shout instance"); + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances, + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-event-table-packed-redteam"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-event-table-packed-redteam"); + let err = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("verification should fail due to event-table hash linkage mismatch"); + + // Keep the assertion stable but informative. 
+ let _ = err; +} diff --git a/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_no_shared_cpu_bus_linkage_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_no_shared_cpu_bus_linkage_redteam.rs new file mode 100644 index 00000000..ea6c1191 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_no_shared_cpu_bus_linkage_redteam.rs @@ -0,0 +1,265 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 3 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=3 for packed ADD (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + 
"build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_shout_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + // Packed-key layout: [lhs_u32, rhs_u32, carry_bit] + let mut packed = [F::ZERO; 3]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = (lhs_u64 as u32) as u64; + let rhs = (rhs_u64 as u32) as u64; + let carry = (lhs.wrapping_add(rhs) >> 32) & 1; + packed[0] = F::from_u64(lhs); + packed[1] = F::from_u64(rhs); + packed[2] = F::from_u64(carry); + } + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + + Ok(z) +} + +fn run_no_shared_cpu_bus_shout_linkage_redteam(select_tamper_col: FSelect) +where + FSelect: Fn(&Rv32TraceCcsLayout) -> usize, +{ + // Program: + // - ADDI x1, x0, 1 + // - ADDI x2, x1, 2 + // - HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, 
+ rd: 2, + rs1: 1, + imm: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Find an active row with a lookup, then tamper with `shout_val` in the trace witness. + let mut tamper_row: Option = None; + for row in 0..layout.t { + let idx = layout.cell(layout.trace.shout_has_lookup, row); + if w[idx - layout.m_in] == F::ONE { + tamper_row = Some(row); + break; + } + } + let row = tamper_row.expect("expected at least one Shout lookup in the trace"); + let tamper_col = select_tamper_col(&layout); + let tamper_idx = layout.cell(tamper_col, row); + w[tamper_idx - layout.m_in] += F::ONE; + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (committed to the tampered witness). 
+ let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: ADD table, 1 lane (honest sidecar witness). + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let add_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 3, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Add, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let add_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ add_lut_inst.d * add_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("ADD Shout z"); + let add_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &add_z); + let add_c = l.commit(&add_Z); + let add_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![add_c], + ..add_lut_inst + }; + let add_lut_wit = LutWitness { mats: vec![add_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(add_lut_inst, add_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // In no-shared mode, shout linkage is validated during Route-A verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-redteam"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("tampered trace shout_val should still admit a proof object before verify checks"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered trace shout linkage must fail during no-shared verification"); +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_val_linkage_redteam() { + run_no_shared_cpu_bus_shout_linkage_redteam(|layout| layout.trace.shout_val); +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_lhs_linkage_redteam() { + run_no_shared_cpu_bus_shout_linkage_redteam(|layout| layout.trace.shout_lhs); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_sub_no_shared_cpu_bus_linkage_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_sub_no_shared_cpu_bus_linkage_redteam.rs new file mode 100644 index 00000000..c237f9bf --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_sub_no_shared_cpu_bus_linkage_redteam.rs @@ -0,0 +1,249 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + 
decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 3 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=3 for packed SUB (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_shout_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + 
z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + // Packed-key layout: [lhs_u32, rhs_u32, borrow_bit] + let mut packed = [F::ZERO; 3]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs_u32 = lhs_u64 as u32; + let rhs_u32 = rhs_u64 as u32; + let borrow = (lhs_u32 < rhs_u32) as u64; + packed[0] = F::from_u64(lhs_u32 as u64); + packed[1] = F::from_u64(rhs_u32 as u64); + packed[2] = F::from_u64(borrow); + } + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_sub_linkage_redteam() { + // Program: + // - LUI x1, 0 (x1 = 0) + // - LUI x2, 1 (x2 = 4096) + // - SUB x3, x1, x2 + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0 }, + RiscvInstruction::Lui { rd: 2, imm: 1 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sub, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, 
&exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Find an active row with a lookup, then tamper with `shout_val` in the trace witness. + let mut tamper_row: Option = None; + for row in 0..layout.t { + let idx = layout.cell(layout.trace.shout_has_lookup, row); + if w[idx - layout.m_in] == F::ONE { + tamper_row = Some(row); + break; + } + } + let row = tamper_row.expect("expected at least one Shout lookup in the trace"); + let val_idx = layout.cell(layout.trace.shout_val, row); + w[val_idx - layout.m_in] += F::ONE; + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (committed to the tampered witness). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SUB table, 1 lane (honest sidecar witness). 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sub).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sub_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 3, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sub, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let sub_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sub_lut_inst.d * sub_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SUB Shout z"); + let sub_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sub_z); + let sub_c = l.commit(&sub_Z); + let sub_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sub_c], + ..sub_lut_inst + }; + let sub_lut_wit = LutWitness { mats: vec![sub_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sub_lut_inst, sub_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // In no-shared mode, shout linkage is validated during Route-A verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sub-redteam"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("tampered trace shout_val should still admit a proof object before verify checks"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sub-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered trace shout_val must fail during no-shared shout linkage verification"); +} diff --git a/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_xor_no_shared_cpu_bus_linkage_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_xor_no_shared_cpu_bus_linkage_redteam.rs new file mode 100644 index 00000000..03dc2959 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/linkage_redteam/riscv_trace_shout_xor_no_shared_cpu_bus_linkage_redteam.rs @@ -0,0 +1,424 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_ccs::Mat; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, 
StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn plan_paged_shout_addr(m: usize, m_in: usize, t: usize, ell_addr: usize, lanes: usize) -> Result, String> { + if t == 0 { + return Err("plan_paged_shout_addr: t must be >= 1".into()); + } + if m_in > m { + return Err(format!("plan_paged_shout_addr: m_in={m_in} > m={m}")); + } + if lanes == 0 { + return Err("plan_paged_shout_addr: lanes must be >= 1".into()); + } + if ell_addr == 0 { + return Err("plan_paged_shout_addr: ell_addr must be >= 1".into()); + } + + let avail = m - m_in; + let max_bus_cols_total = avail / t; + let per_lane_capacity = max_bus_cols_total / lanes; + if per_lane_capacity < 3 { + return Err(format!( + "plan_paged_shout_addr: insufficient per-lane capacity (need >=3 cols per lane, have {per_lane_capacity})" + )); + } + let max_addr_cols_per_page = per_lane_capacity - 2; + + let mut out = Vec::new(); + let mut remaining = ell_addr; + while remaining > 0 { + let take = remaining.min(max_addr_cols_per_page); + out.push(take); + remaining -= take; + } + Ok(out) +} + +fn build_paged_shout_only_bus_zs( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result>, String> { + if x_prefix.len() != m_in { + return Err(format!( + "build_paged_shout_only_bus_zs: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_paged_shout_only_bus_zs: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let page_ell_addrs = plan_paged_shout_addr(m, m_in, t, ell_addr, lanes)?; + let mut out: Vec> = Vec::with_capacity(page_ell_addrs.len()); + + let mut bit_base: usize = 0; + for (page_idx, &page_ell_addr) in 
page_ell_addrs.iter().enumerate() { + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((page_ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_paged_shout_only_bus_zs: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_paged_shout_only_bus_zs: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + for (local_idx, col_id) in cols.addr_bits.clone().enumerate() { + let bit_idx = bit_base + .checked_add(local_idx) + .ok_or_else(|| "build_paged_shout_only_bus_zs: bit index overflow".to_string())?; + if bit_idx >= 64 { + return Err(format!( + "build_paged_shout_only_bus_zs: bit_idx={bit_idx} out of range for u64 key (page_idx={page_idx})" + )); + } + let b = if has { (lane.key[j] >> bit_idx) & 1 } else { 0 }; + z[bus.bus_cell(col_id, j)] = if b == 1 { F::ONE } else { F::ZERO }; + } + } + } + + out.push(z); + bit_base = bit_base + .checked_add(page_ell_addr) + .ok_or_else(|| "build_paged_shout_only_bus_zs: bit_base overflow".to_string())?; + } + if bit_base != ell_addr { + return Err(format!( + "build_paged_shout_only_bus_zs: paging mismatch (got bit_base={bit_base}, expected ell_addr={ell_addr})" + )); + } + + Ok(out) +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_xor_paging_linkage_redteam() { + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 2, + rs1: 0, + imm: 1, 
+ }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Xor, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment. + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: XOR table, 1 lane, bit-addressed (ell_addr=64) paged across mats. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Xor).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let ell_addr = 64usize; + let lanes = 1usize; + let page_ell_addrs = plan_paged_shout_addr(ccs.m, layout.m_in, t, ell_addr, lanes).expect("paging plan"); + + let paged_zs = build_paged_shout_only_bus_zs(ccs.m, layout.m_in, t, ell_addr, lanes, &shout_lanes, &x) + .expect("XOR Shout paged z"); + assert_eq!(paged_zs.len(), page_ell_addrs.len(), "z/page drift"); + + let mut mats: Vec> = Vec::with_capacity(paged_zs.len()); + let mut comms: Vec = Vec::with_capacity(paged_zs.len()); + for z in paged_zs { + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + comms.push(l.commit(&Z)); + mats.push(Z); + } + + let xor_lut_inst = LutInstance:: { + table_id: shout_table_ids[0], + comms, + k: 0, + d: ell_addr, + n_side: 2, + steps: t, + lanes, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcode { + opcode: RiscvOpcode::Xor, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let xor_lut_wit = LutWitness { mats }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(xor_lut_inst, xor_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let mut steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-xor-paged-redteam"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + // Redteam: tamper a LUT commitment in the instance; verifier must reject. 
+ if steps_instance[0].lut_insts[0].comms.len() > 1 { + steps_instance[0].lut_insts[0].comms[1] = steps_instance[0].lut_insts[0].comms[0].clone(); + } else { + steps_instance[0].lut_insts[0].comms[0] = steps_instance[0].mcs_inst.c.clone(); + } + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-xor-paged-redteam"); + let res = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ); + assert!(res.is_err(), "expected verification failure after paging-commit tamper"); +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_table_id_mismatch_redteam() { + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Xor, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + 
let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Extract real XOR events, but intentionally prove them under OR table semantics. + // For this operand set XOR==OR on every active lookup row, so has/val/lhs/rhs linkage + // alone cannot distinguish the wrong table selection. + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Xor).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let ell_addr = 64usize; + let lanes = 1usize; + let page_ell_addrs = plan_paged_shout_addr(ccs.m, layout.m_in, t, ell_addr, lanes).expect("paging plan"); + + let paged_zs = build_paged_shout_only_bus_zs(ccs.m, layout.m_in, t, ell_addr, lanes, &shout_lanes, &x) + .expect("OR-by-XOR-event paged z"); + assert_eq!(paged_zs.len(), page_ell_addrs.len(), "z/page drift"); + + let mut mats: Vec> = Vec::with_capacity(paged_zs.len()); + let mut comms: Vec = Vec::with_capacity(paged_zs.len()); + for z in paged_zs { + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + comms.push(l.commit(&Z)); + mats.push(Z); + } + + let wrong_lut_inst = LutInstance:: { + table_id: RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Or).0, + comms, + k: 0, + d: ell_addr, + n_side: 2, + steps: t, + lanes, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcode { + opcode: RiscvOpcode::Or, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let wrong_lut_wit = LutWitness { mats }; + + let steps_witness = 
vec![StepWitnessBundle { + mcs, + lut_instances: vec![(wrong_lut_inst, wrong_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-table-id-redteam"); + let proof = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) + .expect("prove"); + + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-table-id-redteam"); + let res = fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ); + // Legacy no-shared mode does not carry decode-lookup openings, so table-id linkage + // is not enforced here after removing `trace.shout_table_id`. + // Shared trace-wiring mode enforces table-id linkage via decode-backed openings. + assert!( + res.is_ok(), + "legacy no-shared path should accept this table-id aliasing case without decode linkage" + ); +} diff --git a/crates/neo-fold/tests/mixed_shout_table_sizes.rs b/crates/neo-fold/tests/suites/trace_shout/mixed_shout_table_sizes.rs similarity index 96% rename from crates/neo-fold/tests/mixed_shout_table_sizes.rs rename to crates/neo-fold/tests/suites/trace_shout/mixed_shout_table_sizes.rs index 5e2ecffe..09d6bf7d 100644 --- a/crates/neo-fold/tests/mixed_shout_table_sizes.rs +++ b/crates/neo-fold/tests/suites/trace_shout/mixed_shout_table_sizes.rs @@ -48,16 +48,7 @@ fn setup_ajtai_pp(m: usize, seed: u64) -> AjtaiSModule { } fn default_mixers() -> Mixers { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn create_identity_ccs(n: usize) -> CcsStructure { @@ -90,6 +81,7 @@ 
fn make_shout_instance( let ell = table.n_side.trailing_zeros() as usize; ( neo_memory::witness::LutInstance { + table_id: table.table_id, comms: Vec::new(), k: table.k, d: table.d, @@ -99,6 +91,8 @@ fn make_shout_instance( ell, table_spec: None, table: table.content.clone(), + addr_group: None, + selector_group: None, }, neo_memory::witness::LutWitness { mats: Vec::new() }, ) @@ -228,7 +222,7 @@ fn mixed_shout_tables_16_and_256_entries_same_step() { let mut tr_prove = Poseidon2Transcript::new(b"mixed-shout-sizes"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -255,7 +249,7 @@ fn mixed_shout_tables_16_and_256_entries_same_step() { let mut tr_verify = Poseidon2Transcript::new(b"mixed-shout-sizes"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, diff --git a/crates/neo-fold/tests/suites/trace_shout/mod.rs b/crates/neo-fold/tests/suites/trace_shout/mod.rs new file mode 100644 index 00000000..84bd6e84 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/mod.rs @@ -0,0 +1,8 @@ +mod implicit_shout_table_spec_tests; +mod mixed_shout_table_sizes; +mod multi_table_shout_tests; +mod range_check_lookup_tests; +mod shout_identity_u32_range_check; +mod shout_multi_lookup_implicit_table_spec; +mod shout_multi_lookup_per_step; +mod shout_padded_binary_table; diff --git a/crates/neo-fold/tests/multi_table_shout_tests.rs b/crates/neo-fold/tests/suites/trace_shout/multi_table_shout_tests.rs similarity index 97% rename from crates/neo-fold/tests/multi_table_shout_tests.rs rename to crates/neo-fold/tests/suites/trace_shout/multi_table_shout_tests.rs index c4679a57..a15ac24d 100644 --- a/crates/neo-fold/tests/multi_table_shout_tests.rs +++ b/crates/neo-fold/tests/suites/trace_shout/multi_table_shout_tests.rs @@ -53,16 +53,7 @@ fn setup_ajtai_pp(m: usize, seed: u64) -> AjtaiSModule { } fn 
default_mixers() -> Mixers { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn create_identity_ccs(n: usize) -> CcsStructure { @@ -95,6 +86,7 @@ fn make_shout_instance( let ell = table.n_side.trailing_zeros() as usize; ( neo_memory::witness::LutInstance { + table_id: table.table_id, comms: Vec::new(), k: table.k, d: table.d, @@ -104,6 +96,8 @@ fn make_shout_instance( ell, table_spec: None, table: table.content.clone(), + addr_group: None, + selector_group: None, }, neo_memory::witness::LutWitness { mats: Vec::new() }, ) @@ -243,7 +237,7 @@ fn multi_table_shout_two_tables() { let mut tr_prove = Poseidon2Transcript::new(b"multi-table-two"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -258,7 +252,7 @@ fn multi_table_shout_two_tables() { let mut tr_verify = Poseidon2Transcript::new(b"multi-table-two"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -356,7 +350,7 @@ fn multi_table_shout_three_tables_interleaved() { let mut tr_prove = Poseidon2Transcript::new(b"multi-table-three"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -371,7 +365,7 @@ fn multi_table_shout_three_tables_interleaved() { let mut tr_verify = Poseidon2Transcript::new(b"multi-table-three"); let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -440,7 +434,7 @@ fn multi_table_wrong_table_value_fails() { let mut tr_prove = Poseidon2Transcript::new(b"multi-table-wrong"); let proof_result = fold_shard_prove( - 
FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -459,7 +453,7 @@ fn multi_table_wrong_table_value_fails() { let mut tr_verify = Poseidon2Transcript::new(b"multi-table-wrong"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let verify_result = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -548,7 +542,7 @@ fn multi_table_optional_lookups() { let mut tr_prove = Poseidon2Transcript::new(b"multi-table-optional"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -563,7 +557,7 @@ fn multi_table_optional_lookups() { let mut tr_verify = Poseidon2Transcript::new(b"multi-table-optional"); let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, diff --git a/crates/neo-fold/tests/range_check_lookup_tests.rs b/crates/neo-fold/tests/suites/trace_shout/range_check_lookup_tests.rs similarity index 96% rename from crates/neo-fold/tests/range_check_lookup_tests.rs rename to crates/neo-fold/tests/suites/trace_shout/range_check_lookup_tests.rs index ec3e9a6c..8213ace4 100644 --- a/crates/neo-fold/tests/range_check_lookup_tests.rs +++ b/crates/neo-fold/tests/suites/trace_shout/range_check_lookup_tests.rs @@ -51,16 +51,7 @@ fn setup_ajtai_pp(m: usize, seed: u64) -> AjtaiSModule { } fn default_mixers() -> Mixers { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn create_identity_ccs(n: usize) -> CcsStructure { @@ -104,6 +95,7 @@ fn make_shout_instance( let ell = table.n_side.trailing_zeros() as usize; ( neo_memory::witness::LutInstance { + table_id: table.table_id, comms: Vec::new(), k: 
table.k, d: table.d, @@ -113,6 +105,8 @@ fn make_shout_instance( ell, table_spec: None, table: table.content.clone(), + addr_group: None, + selector_group: None, }, neo_memory::witness::LutWitness { mats: Vec::new() }, ) @@ -220,7 +214,7 @@ fn range_check_4bit_valid() { let mut tr_prove = Poseidon2Transcript::new(b"range-4bit-valid"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -235,7 +229,7 @@ fn range_check_4bit_valid() { let mut tr_verify = Poseidon2Transcript::new(b"range-4bit-valid"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -275,7 +269,7 @@ fn range_check_4bit_invalid_value_fails() { let mut tr_prove = Poseidon2Transcript::new(b"range-4bit-invalid"); let proof_result = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -291,7 +285,7 @@ fn range_check_4bit_invalid_value_fails() { let steps_public = [StepInstanceBundle::from(&step_bundle)]; assert!( fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -360,7 +354,7 @@ fn range_check_nibble_decomposition() { let mut tr_prove = Poseidon2Transcript::new(b"range-nibble-decomp"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -375,7 +369,7 @@ fn range_check_nibble_decomposition() { let mut tr_verify = Poseidon2Transcript::new(b"range-nibble-decomp"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -457,7 +451,7 @@ fn range_check_combined_with_addition() { let mut tr_prove = Poseidon2Transcript::new(b"range-with-addition"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -472,7 
+466,7 @@ fn range_check_combined_with_addition() { let mut tr_verify = Poseidon2Transcript::new(b"range-with-addition"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -514,7 +508,7 @@ fn range_check_wrong_value_claimed_fails() { let mut tr_prove = Poseidon2Transcript::new(b"range-wrong-value-claimed"); let proof_result = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -530,7 +524,7 @@ fn range_check_wrong_value_claimed_fails() { let steps_public = [StepInstanceBundle::from(&step_bundle)]; assert!( fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -582,7 +576,7 @@ fn range_check_boundary_values() { let mut tr_prove = Poseidon2Transcript::new(b"range-boundary"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -597,7 +591,7 @@ fn range_check_boundary_values() { let mut tr_verify = Poseidon2Transcript::new(b"range-boundary"); let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/mod.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/mod.rs new file mode 100644 index 00000000..c9224883 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/mod.rs @@ -0,0 +1,13 @@ +mod riscv_trace_shout_bitwise_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_div_rem_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_divu_remu_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_eq_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_mul_no_shared_cpu_bus_semantics_redteam; +mod 
riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_mulhu_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_sll_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_slt_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_sltu_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_sra_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_srl_no_shared_cpu_bus_semantics_redteam; +mod riscv_trace_shout_sub_no_shared_cpu_bus_semantics_redteam; diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_bitwise_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_bitwise_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..a352bc55 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_bitwise_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,272 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use 
neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z_packed_bitwise( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 34 { + return Err(format!( + "build_shout_only_bus_z_packed_bitwise: expected ell_addr=34 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z_packed_bitwise: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != t { + return Err("build_shout_only_bus_z_packed_bitwise: lane length mismatch".into()); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z_packed_bitwise: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + let mut packed = [F::ZERO; 34]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let lhs_u32 = lhs_u64 as u32; + let rhs_u32 = rhs_u64 as u32; + + packed[0] = F::from_u64(lhs_u32 as u64); + packed[1] = F::from_u64(rhs_u32 as u64); + + for i in 0..16usize { + let a = (lhs_u32 >> (2 * i)) & 3; + let b = (rhs_u32 >> (2 * i)) & 3; + packed[2 + i] = F::from_u64(a as u64); + packed[18 + i] = F::from_u64(b as u64); + } + } + + for (idx, col_id) in 
cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_bitwise_packed_semantics_redteam() { + // Same program as the e2e test; tamper a single packed digit. + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Xor, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Or, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 4, + rs1: 3, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 34usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. 
+ let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + let t = exec.rows.len(); + let shout_table_ids = vec![ + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::And).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Or).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Xor).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Andn).0, + ]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + assert_eq!(shout_lanes.len(), 4); + + // Tamper the OR packed witness: flip lhs_digit[0] at the first OR lookup row. 
+ let or_lane = &shout_lanes[1]; + let j = or_lane + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one OR lookup"); + + let mut lut_instances = Vec::new(); + for (idx, opcode) in [RiscvOpcode::And, RiscvOpcode::Or, RiscvOpcode::Xor, RiscvOpcode::Andn] + .into_iter() + .enumerate() + { + let inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 34, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { opcode, xlen: 32 }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut z = + build_shout_only_bus_z_packed_bitwise(ccs.m, layout.m_in, t, inst.d * inst.ell, &shout_lanes[idx], &x) + .expect("packed bitwise z"); + + if opcode == RiscvOpcode::Or { + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 34usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let lhs_digit0_col_id = cols + .addr_bits + .clone() + .nth(2) + .expect("expected lhs_digit0 at addr_bits[2]"); + let cell = bus.bus_cell(lhs_digit0_col_id, j); + z[cell] = if z[cell] == F::ZERO { F::ONE } else { F::ZERO }; + } + + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + let c = l.commit(&Z); + let inst = LutInstance:: { comms: vec![c], ..inst }; + let wit = LutWitness { mats: vec![Z] }; + lut_instances.push((inst, wit)); + } + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances, + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either: + // - reject because the tampered witness no longer satisfies the protocol invariants, or + // - emit a proof that fails verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-bitwise-packed-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-bitwise-packed-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed bitwise digit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_div_rem_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_div_rem_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..7bc8e57a --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_div_rem_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,721 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::{Field, PrimeCharacteristicRing}; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn div_signed(lhs: u32, rhs: u32) -> u32 { + let lhs_i = lhs as i32; + let rhs_i = rhs as i32; + if rhs_i == 0 { + return u32::MAX; + } + if lhs_i == i32::MIN && rhs_i == -1 { + return lhs; + } + (lhs_i / rhs_i) as u32 +} + +fn rem_signed(lhs: u32, rhs: u32) -> u32 { + let lhs_i = lhs as i32; + let rhs_i = rhs as i32; + if rhs_i == 0 { + return lhs; + } + if lhs_i == i32::MIN && rhs_i == -1 { + return 0; + } + (lhs_i % rhs_i) as u32 +} + +fn plan_paged_ell_addrs( + m: usize, + m_in: usize, + steps: usize, + ell_addr: usize, + lanes: usize, +) -> Result, String> { + if steps == 0 { + return Err("plan_paged_ell_addrs: steps=0".into()); + } + if m_in > m { + return Err(format!("plan_paged_ell_addrs: m_in({m_in}) > m({m})")); + } + let lanes = lanes.max(1); + + let avail = m - m_in; + let max_bus_cols_total = avail / steps; + let per_lane_capacity = max_bus_cols_total / lanes; + if per_lane_capacity < 3 { + return Err(format!( + "plan_paged_ell_addrs: insufficient capacity (need >=3 cols/lane for [addr_bits>=1,has_lookup,val], have per_lane_capacity={per_lane_capacity}; m={m}, m_in={m_in}, steps={steps}, lanes={lanes})" + )); + } + let max_addr_cols_per_page = per_lane_capacity - 2; + if max_addr_cols_per_page == 0 { + return Err("plan_paged_ell_addrs: max_addr_cols_per_page=0".into()); + } + if ell_addr == 0 { + return Err("plan_paged_ell_addrs: ell_addr=0".into()); + } + + let mut pages = Vec::new(); + let mut remaining = ell_addr; + while remaining > 0 { + let take = remaining.min(max_addr_cols_per_page); + pages.push(take); + remaining -= take; + } + Ok(pages) +} + +fn build_paged_shout_only_bus_zs_packed_div( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> 
Result>, String> { + if ell_addr != 43 { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_div: expected ell_addr=43 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_div: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != t { + return Err("build_paged_shout_only_bus_zs_packed_div: lane length mismatch".into()); + } + + let page_ell_addrs = plan_paged_ell_addrs(m, m_in, t, ell_addr, /*lanes=*/ 1)?; + + let mut out = Vec::with_capacity(page_ell_addrs.len()); + let mut base_idx = 0usize; + for &page_ell_addr in page_ell_addrs.iter() { + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((page_ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err( + "build_paged_shout_only_bus_zs_packed_div: expected 1 shout instance and 0 twist instances".into(), + ); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + let addr_cols: Vec = cols.addr_bits.clone().collect(); + if addr_cols.len() != page_ell_addr { + return Err("build_paged_shout_only_bus_zs_packed_div: addr_bits len mismatch".into()); + } + + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + let mut packed = [F::ZERO; 43]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let out_val = lane_data.value[j] as u32; + let expected_out = div_signed(lhs, rhs); + if out_val != expected_out { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_div: lane.value mismatch at j={j} (got 
{out_val:#x}, expected {expected_out:#x})" + )); + } + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + let lhs_abs = if lhs_sign == 0 { lhs } else { lhs.wrapping_neg() }; + let rhs_abs = if rhs == 0 { + 0u32 + } else if rhs_sign == 0 { + rhs + } else { + rhs.wrapping_neg() + }; + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = if rhs == 0 { 1u32 } else { 0u32 }; + + let (q_abs, r_abs) = if rhs == 0 { + (0u32, 0u32) + } else { + (lhs_abs / rhs_abs, lhs_abs % rhs_abs) + }; + let q_is_zero = if q_abs == 0 { 1u32 } else { 0u32 }; + let q_f = F::from_u64(q_abs as u64); + let q_inv = if q_f == F::ZERO { F::ZERO } else { q_f.inverse() }; + + let diff = if rhs == 0 { 0u32 } else { r_abs.wrapping_sub(rhs_abs) }; + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(q_abs as u64); + packed[3] = F::from_u64(r_abs as u64); + packed[4] = rhs_inv; + packed[5] = if rhs_is_zero == 1 { F::ONE } else { F::ZERO }; + packed[6] = if lhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[7] = if rhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[8] = q_inv; + packed[9] = if q_is_zero == 1 { F::ONE } else { F::ZERO }; + packed[10] = F::from_u64(diff as u64); + for bit in 0..32usize { + packed[11 + bit] = if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + + for (local_idx, &col_id) in addr_cols.iter().enumerate() { + let packed_idx = base_idx + local_idx; + if packed_idx >= ell_addr { + return Err("build_paged_shout_only_bus_zs_packed_div: paging overflow".into()); + } + z[bus.bus_cell(col_id, j)] = packed[packed_idx]; + } + } + + out.push(z); + base_idx += page_ell_addr; + } + + if base_idx != ell_addr { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_div: paging mismatch (got base_idx={base_idx}, expected ell_addr={ell_addr})" + )); + } + + Ok(out) +} + +fn build_paged_shout_only_bus_zs_packed_rem( + m: 
usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> Result>, String> { + if ell_addr != 43 { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_rem: expected ell_addr=43 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_rem: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != t { + return Err("build_paged_shout_only_bus_zs_packed_rem: lane length mismatch".into()); + } + + let page_ell_addrs = plan_paged_ell_addrs(m, m_in, t, ell_addr, /*lanes=*/ 1)?; + + let mut out = Vec::with_capacity(page_ell_addrs.len()); + let mut base_idx = 0usize; + for &page_ell_addr in page_ell_addrs.iter() { + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((page_ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err( + "build_paged_shout_only_bus_zs_packed_rem: expected 1 shout instance and 0 twist instances".into(), + ); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + let addr_cols: Vec = cols.addr_bits.clone().collect(); + if addr_cols.len() != page_ell_addr { + return Err("build_paged_shout_only_bus_zs_packed_rem: addr_bits len mismatch".into()); + } + + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + let mut packed = [F::ZERO; 43]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let out_val = lane_data.value[j] as u32; + let expected_out = rem_signed(lhs, 
rhs); + if out_val != expected_out { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_rem: lane.value mismatch at j={j} (got {out_val:#x}, expected {expected_out:#x})" + )); + } + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + let lhs_abs = if lhs_sign == 0 { lhs } else { lhs.wrapping_neg() }; + let rhs_abs = if rhs == 0 { + 0u32 + } else if rhs_sign == 0 { + rhs + } else { + rhs.wrapping_neg() + }; + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = if rhs == 0 { 1u32 } else { 0u32 }; + + let (q_abs, r_abs) = if rhs == 0 { + (0u32, 0u32) + } else { + (lhs_abs / rhs_abs, lhs_abs % rhs_abs) + }; + let r_is_zero = if r_abs == 0 { 1u32 } else { 0u32 }; + let r_f = F::from_u64(r_abs as u64); + let r_inv = if r_f == F::ZERO { F::ZERO } else { r_f.inverse() }; + + let diff = if rhs == 0 { 0u32 } else { r_abs.wrapping_sub(rhs_abs) }; + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(q_abs as u64); + packed[3] = F::from_u64(r_abs as u64); + packed[4] = rhs_inv; + packed[5] = if rhs_is_zero == 1 { F::ONE } else { F::ZERO }; + packed[6] = if lhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[7] = if rhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[8] = r_inv; + packed[9] = if r_is_zero == 1 { F::ONE } else { F::ZERO }; + packed[10] = F::from_u64(diff as u64); + for bit in 0..32usize { + packed[11 + bit] = if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + + for (local_idx, &col_id) in addr_cols.iter().enumerate() { + let packed_idx = base_idx + local_idx; + if packed_idx >= ell_addr { + return Err("build_paged_shout_only_bus_zs_packed_rem: paging overflow".into()); + } + z[bus.bus_cell(col_id, j)] = packed[packed_idx]; + } + } + + out.push(z); + base_idx += page_ell_addr; + } + + if base_idx != ell_addr { + return Err(format!( + "build_paged_shout_only_bus_zs_packed_rem: paging 
mismatch (got base_idx={base_idx}, expected ell_addr={ell_addr})" + )); + } + + Ok(out) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_div_rem_semantics_redteam() { + // Same program as the e2e test; tamper: + // - DIV rhs_is_zero on a non-trivial lookup row, and + // - REM rhs_is_zero on a non-trivial lookup row. + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: -7 }, + RiscvInstruction::Lui { rd: 2, imm: 3 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 4, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Lui { rd: 1, imm: -1 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 5, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 6, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Lui { rd: 1, imm: -524_288 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: -1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 7, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 8, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 9, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 10, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let mut exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 
8).expect("from_trace_padded_pow2"); + let tables = RiscvShoutTables::new(32); + let div_id = tables.opcode_to_id(RiscvOpcode::Div); + let rem_id = tables.opcode_to_id(RiscvOpcode::Rem); + for row in exec.rows.iter_mut() { + if row.active { + row.shout_events + .retain(|ev| ev.shout_id == div_id || ev.shout_id == rem_id); + } + } + let div_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + .. + }) + ) + }) + .count(); + let div_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + .. + }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == div_id) + }) + .count(); + let rem_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + .. + }) + ) + }) + .count(); + let rem_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + .. 
+ }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == rem_id) + }) + .count(); + assert!(div_rows > 0, "expected at least one DIV row"); + assert!(rem_rows > 0, "expected at least one REM row"); + assert!( + div_shout_rows > 0 && div_shout_rows <= div_rows, + "native DIV shout coverage mismatch (div_rows={div_rows}, div_shout_rows={div_shout_rows})" + ); + assert!( + rem_shout_rows > 0 && rem_shout_rows <= rem_rows, + "native REM shout coverage mismatch (rem_rows={rem_rows}, rem_shout_rows={rem_shout_rows})" + ); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instances: DIV and REM packed, 1 lane each. 
+ let t = exec.rows.len(); + let shout_table_ids = vec![div_id.0, rem_id.0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + assert_eq!(shout_lanes.len(), 2); + + let div_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 43, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Div, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let rem_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 43, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Rem, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let page_ell_addrs = + plan_paged_ell_addrs(ccs.m, layout.m_in, t, /*ell_addr=*/ 43, /*lanes=*/ 1).expect("paging plan"); + let page0_ell_addr = *page_ell_addrs.get(0).expect("non-empty paging plan"); + + let mut div_zs = + build_paged_shout_only_bus_zs_packed_div(ccs.m, layout.m_in, t, div_inst.d * div_inst.ell, &shout_lanes[0], &x) + .expect("DIV packed z"); + let j = shout_lanes[0] + .has_lookup + .iter() + .enumerate() + .find_map(|(idx, &has)| { + if !has { + return None; + } + let (lhs_u64, rhs_u64) = uninterleave_bits(shout_lanes[0].key[idx] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + if lhs != 0 || rhs != 0 { + Some(idx) + } else { + None + } + }) + .expect("expected a non-trivial DIV lookup"); + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((page0_ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let div_rhs_is_zero_col_id = cols + .addr_bits + .clone() + .nth(5) + .expect("expected addr_bits[5] for rhs_is_zero"); + let cell = 
bus.bus_cell(div_rhs_is_zero_col_id, j); + div_zs[0][cell] = if div_zs[0][cell] == F::ONE { F::ZERO } else { F::ONE }; + + let mut div_comms = Vec::with_capacity(div_zs.len()); + let mut div_mats = Vec::with_capacity(div_zs.len()); + for z in div_zs { + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + div_comms.push(l.commit(&Z)); + div_mats.push(Z); + } + let div_inst = LutInstance:: { + table_id: 0, + comms: div_comms, + ..div_inst + }; + let div_wit = LutWitness { mats: div_mats }; + + let mut rem_zs = + build_paged_shout_only_bus_zs_packed_rem(ccs.m, layout.m_in, t, rem_inst.d * rem_inst.ell, &shout_lanes[1], &x) + .expect("REM packed z"); + let j_rem = shout_lanes[1] + .has_lookup + .iter() + .enumerate() + .find_map(|(idx, &has)| { + if !has { + return None; + } + let (lhs_u64, rhs_u64) = uninterleave_bits(shout_lanes[1].key[idx] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + if lhs != 0 || rhs != 0 { + Some(idx) + } else { + None + } + }) + .expect("expected a non-trivial REM lookup"); + let rem_rhs_is_zero_col_id = cols + .addr_bits + .clone() + .nth(5) + .expect("expected addr_bits[5] for rhs_is_zero"); + let rem_cell = bus.bus_cell(rem_rhs_is_zero_col_id, j_rem); + rem_zs[0][rem_cell] = if rem_zs[0][rem_cell] == F::ONE { F::ZERO } else { F::ONE }; + + let mut rem_comms = Vec::with_capacity(rem_zs.len()); + let mut rem_mats = Vec::with_capacity(rem_zs.len()); + for z in rem_zs { + let Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z); + rem_comms.push(l.commit(&Z)); + rem_mats.push(Z); + } + let rem_inst = LutInstance:: { + table_id: 0, + comms: rem_comms, + ..rem_inst + }; + let rem_wit = LutWitness { mats: rem_mats }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(div_inst, div_wit), (rem_inst, rem_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut 
tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-div-rem-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-div-rem-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed DIV/REM rhs_is_zero flags must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_divu_remu_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_divu_remu_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..6b0037a2 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_divu_remu_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,509 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use 
neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::{Field, PrimeCharacteristicRing}; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn divu(lhs: u32, rhs: u32) -> u32 { + if rhs == 0 { + u32::MAX + } else { + lhs / rhs + } +} + +fn remu(lhs: u32, rhs: u32) -> u32 { + if rhs == 0 { + lhs + } else { + lhs % rhs + } +} + +fn build_shout_only_bus_z_packed_divu( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z_packed_divu: expected ell_addr=38 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z_packed_divu: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != t { + return Err("build_shout_only_bus_z_packed_divu: lane length mismatch".into()); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z_packed_divu: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + let mut packed = [F::ZERO; 38]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let quot = 
lane_data.value[j] as u32; + let expected_quot = divu(lhs, rhs); + if quot != expected_quot { + return Err(format!( + "build_shout_only_bus_z_packed_divu: lane.value mismatch at j={j} (got {quot:#x}, expected {expected_quot:#x})" + )); + } + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = if rhs == 0 { 1u32 } else { 0u32 }; + + let rem = if rhs == 0 { + 0u32 + } else { + ((lhs as u64) % (rhs as u64)) as u32 + }; + let diff = rem.wrapping_sub(rhs); + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(rem as u64); + packed[3] = rhs_inv; + packed[4] = if rhs_is_zero == 1 { F::ONE } else { F::ZERO }; + packed[5] = F::from_u64(diff as u64); + for bit in 0..32usize { + packed[6 + bit] = if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + + Ok(z) +} + +fn build_shout_only_bus_z_packed_remu( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z_packed_remu: expected ell_addr=38 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z_packed_remu: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != t { + return Err("build_shout_only_bus_z_packed_remu: lane length mismatch".into()); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z_packed_remu: expected 1 shout instance and 0 twist 
instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + let mut packed = [F::ZERO; 38]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let rem = lane_data.value[j] as u32; + let expected_rem = remu(lhs, rhs); + if rem != expected_rem { + return Err(format!( + "build_shout_only_bus_z_packed_remu: lane.value mismatch at j={j} (got {rem:#x}, expected {expected_rem:#x})" + )); + } + + let rhs_f = F::from_u64(rhs as u64); + let rhs_inv = if rhs_f == F::ZERO { F::ZERO } else { rhs_f.inverse() }; + let rhs_is_zero = if rhs == 0 { 1u32 } else { 0u32 }; + + let quot = if rhs == 0 { + 0u32 + } else { + (lhs as u64 / rhs as u64) as u32 + }; + let diff = rem.wrapping_sub(rhs); + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(quot as u64); + packed[3] = rhs_inv; + packed[4] = if rhs_is_zero == 1 { F::ONE } else { F::ZERO }; + packed[5] = F::from_u64(diff as u64); + for bit in 0..32usize { + packed[6 + bit] = if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_divu_remu_semantics_redteam() { + // Same program as the e2e test; tamper DIVU diff bit 0. 
+ let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 91, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 7, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + rd: 4, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 5, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + rd: 6, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let mut exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 16).expect("from_trace_padded_pow2"); + let tables = RiscvShoutTables::new(32); + let divu_id = tables.opcode_to_id(RiscvOpcode::Divu); + let remu_id = tables.opcode_to_id(RiscvOpcode::Remu); + for row in exec.rows.iter_mut() { + if row.active { + row.shout_events + .retain(|ev| ev.shout_id == divu_id || ev.shout_id == remu_id); + } + } + let divu_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + .. + }) + ) + }) + .count(); + let divu_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + .. 
+ }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == divu_id) + }) + .count(); + let remu_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + .. + }) + ) + }) + .count(); + let remu_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + .. + }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == remu_id) + }) + .count(); + assert!(divu_rows > 0, "expected at least one DIVU row"); + assert!(remu_rows > 0, "expected at least one REMU row"); + assert!( + divu_shout_rows > 0 && divu_shout_rows <= divu_rows, + "native DIVU shout coverage mismatch (divu_rows={divu_rows}, divu_shout_rows={divu_shout_rows})" + ); + assert!( + remu_shout_rows > 0 && remu_shout_rows <= remu_rows, + "native REMU shout coverage mismatch (remu_rows={remu_rows}, remu_shout_rows={remu_shout_rows})" + ); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. 
+ let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instances: DIVU and REMU packed, 1 lane each (tamper DIVU diff_bit0). + let t = exec.rows.len(); + let shout_table_ids = vec![divu_id.0, remu_id.0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + assert_eq!(shout_lanes.len(), 2); + + let divu_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Divu, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let remu_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Remu, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut divu_z = + build_shout_only_bus_z_packed_divu(ccs.m, layout.m_in, t, divu_inst.d * divu_inst.ell, &shout_lanes[0], &x) + .expect("DIVU packed z"); + let j = shout_lanes[0] + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one DIVU lookup"); + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 38usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = 
&bus.shout_cols[0].lanes[0]; + let diff_bit0_col_id = cols + .addr_bits + .clone() + .nth(6) + .expect("expected addr_bits[6] for diff bit 0"); + let cell = bus.bus_cell(diff_bit0_col_id, j); + divu_z[cell] = if divu_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let divu_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &divu_z); + let divu_c = l.commit(&divu_Z); + let divu_inst = LutInstance:: { + table_id: 0, + comms: vec![divu_c], + ..divu_inst + }; + let divu_wit = LutWitness { mats: vec![divu_Z] }; + + let remu_z = + build_shout_only_bus_z_packed_remu(ccs.m, layout.m_in, t, remu_inst.d * remu_inst.ell, &shout_lanes[1], &x) + .expect("REMU packed z"); + let remu_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &remu_z); + let remu_c = l.commit(&remu_Z); + let remu_inst = LutInstance:: { + table_id: 0, + comms: vec![remu_c], + ..remu_inst + }; + let remu_wit = LutWitness { mats: vec![remu_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(divu_inst, divu_wit), (remu_inst, remu_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-divu-remu-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-divu-remu-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed DIVU diff bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_eq_no_shared_cpu_bus_semantics_redteam.rs 
b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_eq_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..0b3fbc76 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_eq_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,279 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemory, + RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 35 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=35 for packed EQ (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + 
} + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_shout_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + // Packed-key layout (ell_addr=35): + // [lhs_u32, rhs_u32, borrow_bit, diff_bits[0..32]]. 
+ let mut packed = [F::ZERO; 35]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let borrow = if lhs < rhs { 1u32 } else { 0u32 }; + let diff = lhs.wrapping_sub(rhs); + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = if borrow == 1 { F::ONE } else { F::ZERO }; + for bit in 0..32usize { + packed[3 + bit] = if ((diff >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_eq_semantics_redteam() { + // Program (use BEQ to generate an `Eq` Shout event; there is no dedicated `Eq` ALU instruction encoding): + // - LUI x1, 0 (x1 = 0) + // - LUI x2, 1 (x2 = 4096) + // - BEQ x1, x2, +8 (not taken; EQ=0) <-- diff != 0, inv is constrained + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0 }, + RiscvInstruction::Lui { rd: 2, imm: 1 }, + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 1, + rs2: 2, + imm: 8, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive 
rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 35usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: EQ table, 1 lane (tamper the packed borrow bit witness). + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Eq).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let eq_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 35, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Eq, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut eq_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ eq_lut_inst.d * eq_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("EQ Shout z"); + + // Find a lookup row with lhs != rhs, then flip the borrow bit (addr_bits[2]). 
+ let lane0 = &shout_lanes[0]; + let mut tamper_j: Option = None; + for j in 0..t { + if !lane0.has_lookup[j] { + continue; + } + let (lhs_u64, rhs_u64) = uninterleave_bits(lane0.key[j] as u128); + if (lhs_u64 as u32) != (rhs_u64 as u32) { + tamper_j = Some(j); + break; + } + } + let j = tamper_j.expect("expected at least one EQ lookup with lhs != rhs"); + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 35usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let borrow_col_id = cols.addr_bits.clone().nth(2).expect("borrow col id"); + let borrow_cell = bus.bus_cell(borrow_col_id, j); + eq_z[borrow_cell] = if eq_z[borrow_cell] == F::ZERO { F::ONE } else { F::ZERO }; + + let eq_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &eq_z); + let eq_c = l.commit(&eq_Z); + let eq_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![eq_c], + ..eq_lut_inst + }; + let eq_lut_wit = LutWitness { mats: vec![eq_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(eq_lut_inst, eq_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either: + // - reject because the witness no longer satisfies the protocol invariants, or + // - emit a proof that fails verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-eq-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-eq-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed EQ borrow witness must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mul_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mul_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..1d9c672c --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mul_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,314 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 34 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=34 for packed MUL (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let wide = (lhs as u64) * (rhs as u64); + let carry = (wide >> 32) as u32; + + // Packed-key layout (ell_addr=34): [lhs_u32, 
rhs_u32, carry_bits[0..32]]. + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + for bit in 0..32 { + packed[2 + bit] = if ((carry >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_mul_semantics_redteam() { + // Program: + // - LUI x1, 16 (x1 = 65536) + // - LUI x2, 16 (x2 = 65536) + // - MUL x3, x1, x2 (hi=1, lo=0) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 16 }, + RiscvInstruction::Lui { rd: 2, imm: 16 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let mut exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let mul_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Mul); + for row in exec.rows.iter_mut() { + if row.active { + row.shout_events.retain(|ev| ev.shout_id == mul_id); + } + } + let mul_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + .. + }) + ) + }) + .count(); + let mul_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + .. 
+ }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == mul_id) + }) + .count(); + assert!(mul_rows > 0, "expected at least one MUL row"); + assert_eq!( + mul_shout_rows, mul_rows, + "native MUL shout coverage mismatch (mul_rows={mul_rows}, shout_rows={mul_shout_rows})" + ); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 34usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: MUL table, 1 lane (tamper one carry bit). 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Mul).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let mul_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 34, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Mul, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut mul_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ mul_lut_inst.d * mul_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("MUL Shout z"); + + // Flip carry bit 0 on any active lookup row. + let lane0 = &shout_lanes[0]; + let j = lane0 + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one MUL lookup"); + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 34usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let carry_bit0_col_id = cols + .addr_bits + .clone() + .nth(2) + .expect("expected addr_bits[2] for carry bit 0"); + let cell = bus.bus_cell(carry_bit0_col_id, j); + mul_z[cell] = if mul_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let mul_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &mul_z); + let mul_c = l.commit(&mul_Z); + let mul_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![mul_c], + ..mul_lut_inst + }; + let mul_lut_wit = LutWitness { mats: vec![mul_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(mul_lut_inst, mul_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = 
Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-mul-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-mul-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed MUL carry bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..9d46101b --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mulh_mulhsu_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,519 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn mulh_hi_signed(lhs: u32, rhs: u32) -> u32 { + let a = lhs as i32 as i64; + let b = rhs as i32 as i64; + let p = a * b; + (p >> 32) as i32 as u32 +} + +fn mulhsu_hi_signed(lhs: u32, rhs: u32) -> u32 { + let a = lhs as i32 as i64; + let b = rhs as i64; + let p = a * b; + (p >> 32) as i32 as u32 +} + +fn build_shout_only_bus_z_packed_mulh( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z_packed_mulh: expected ell_addr=38 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z_packed_mulh: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != t { + return Err("build_shout_only_bus_z_packed_mulh: lane length mismatch".into()); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z_packed_mulh: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + let mut packed = [F::ZERO; 38]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let 
lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let val = lane_data.value[j] as u32; + let expected_val = mulh_hi_signed(lhs, rhs); + if val != expected_val { + return Err(format!( + "build_shout_only_bus_z_packed_mulh: lane.value mismatch at j={j} (got {val:#x}, expected {expected_val:#x})" + )); + } + + let uprod = (lhs as u64) * (rhs as u64); + let lo = (uprod & 0xffff_ffff) as u32; + let hi = (uprod >> 32) as u32; + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + + let diff = + (val as i128) - (hi as i128) + (lhs_sign as i128) * (rhs as i128) + (rhs_sign as i128) * (lhs as i128); + let two32 = 1_i128 << 32; + if diff < 0 || diff % two32 != 0 { + return Err(format!( + "build_shout_only_bus_z_packed_mulh: invalid k at j={j} (diff={diff})" + )); + } + let k = (diff / two32) as u32; + if k > 2 { + return Err(format!( + "build_shout_only_bus_z_packed_mulh: expected k in {{0,1,2}} at j={j}, got k={k}" + )); + } + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(hi as u64); + packed[3] = if lhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[4] = if rhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[5] = F::from_u64(k as u64); + for bit in 0..32usize { + packed[6 + bit] = if ((lo >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + + Ok(z) +} + +fn build_shout_only_bus_z_packed_mulhsu( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lane_data: &neo_memory::riscv::trace::ShoutLaneOverTime, + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 37 { + return Err(format!( + "build_shout_only_bus_z_packed_mulhsu: expected ell_addr=37 (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z_packed_mulhsu: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.has_lookup.len() != 
t { + return Err("build_shout_only_bus_z_packed_mulhsu: lane length mismatch".into()); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z_packed_mulhsu: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let cols = &bus.shout_cols[0].lanes[0]; + for j in 0..t { + let has = lane_data.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane_data.value[j]) } else { F::ZERO }; + + let mut packed = [F::ZERO; 37]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane_data.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let val = lane_data.value[j] as u32; + let expected_val = mulhsu_hi_signed(lhs, rhs); + if val != expected_val { + return Err(format!( + "build_shout_only_bus_z_packed_mulhsu: lane.value mismatch at j={j} (got {val:#x}, expected {expected_val:#x})" + )); + } + + let uprod = (lhs as u64) * (rhs as u64); + let lo = (uprod & 0xffff_ffff) as u32; + let hi = (uprod >> 32) as u32; + let lhs_sign = (lhs >> 31) & 1; + + let diff = (val as i128) - (hi as i128) + (lhs_sign as i128) * (rhs as i128); + let two32 = 1_i128 << 32; + if diff < 0 || diff % two32 != 0 { + return Err(format!( + "build_shout_only_bus_z_packed_mulhsu: invalid borrow at j={j} (diff={diff})" + )); + } + let borrow = (diff / two32) as u32; + if borrow > 1 { + return Err(format!( + "build_shout_only_bus_z_packed_mulhsu: expected borrow in {{0,1}} at j={j}, got borrow={borrow}" + )); + } + + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(hi as u64); + packed[3] = if lhs_sign == 1 { 
F::ONE } else { F::ZERO }; + packed[4] = if borrow == 1 { F::ONE } else { F::ZERO }; + for bit in 0..32usize { + packed[5 + bit] = if ((lo >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_mulh_mulhsu_semantics_redteam() { + // Same program as the e2e test; tamper a single MULH lo bit. + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: -7, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: -3, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulh, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 5, + rs1: 0, + imm: 13, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhsu, + rd: 6, + rs1: 1, + rs2: 5, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let mut exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + let tables = RiscvShoutTables::new(32); + let mulh_id = tables.opcode_to_id(RiscvOpcode::Mulh); + let mulhsu_id = tables.opcode_to_id(RiscvOpcode::Mulhsu); + for row in exec.rows.iter_mut() { + if row.active { + row.shout_events + .retain(|ev| ev.shout_id == mulh_id || ev.shout_id == mulhsu_id); + } + } + let mulh_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + 
op: RiscvOpcode::Mulh, + .. + }) + ) + }) + .count(); + let mulh_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Mulh, + .. + }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == mulh_id) + }) + .count(); + let mulhsu_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhsu, + .. + }) + ) + }) + .count(); + let mulhsu_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhsu, + .. + }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == mulhsu_id) + }) + .count(); + assert!(mulh_rows > 0, "expected at least one MULH row"); + assert!(mulhsu_rows > 0, "expected at least one MULHSU row"); + assert_eq!( + mulh_shout_rows, mulh_rows, + "native MULH shout coverage mismatch (mulh_rows={mulh_rows}, mulh_shout_rows={mulh_shout_rows})" + ); + assert_eq!( + mulhsu_shout_rows, mulhsu_rows, + "native MULHSU shout coverage mismatch (mulhsu_rows={mulhsu_rows}, mulhsu_shout_rows={mulhsu_shout_rows})" + ); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. 
+ let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instances: MULH and MULHSU packed, 1 lane each (tamper MULH lo bit 0). + let t = exec.rows.len(); + let shout_table_ids = vec![mulh_id.0, mulhsu_id.0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + assert_eq!(shout_lanes.len(), 2); + + let mulh_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Mulh, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let mulhsu_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 37, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Mulhsu, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut mulh_z = + build_shout_only_bus_z_packed_mulh(ccs.m, layout.m_in, t, mulh_inst.d * mulh_inst.ell, &shout_lanes[0], &x) + .expect("MULH packed z"); + let j = shout_lanes[0] + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one MULH lookup"); + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 38usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = 
&bus.shout_cols[0].lanes[0]; + let lo_bit0_col_id = cols + .addr_bits + .clone() + .nth(6) + .expect("expected addr_bits[6] for lo bit 0"); + let cell = bus.bus_cell(lo_bit0_col_id, j); + mulh_z[cell] = if mulh_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let mulh_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &mulh_z); + let mulh_c = l.commit(&mulh_Z); + let mulh_inst = LutInstance:: { + table_id: 0, + comms: vec![mulh_c], + ..mulh_inst + }; + let mulh_wit = LutWitness { mats: vec![mulh_Z] }; + + let mulhsu_z = build_shout_only_bus_z_packed_mulhsu( + ccs.m, + layout.m_in, + t, + mulhsu_inst.d * mulhsu_inst.ell, + &shout_lanes[1], + &x, + ) + .expect("MULHSU packed z"); + let mulhsu_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &mulhsu_z); + let mulhsu_c = l.commit(&mulhsu_Z); + let mulhsu_inst = LutInstance:: { + table_id: 0, + comms: vec![mulhsu_c], + ..mulhsu_inst + }; + let mulhsu_wit = LutWitness { mats: vec![mulhsu_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(mulh_inst, mulh_wit), (mulhsu_inst, mulhsu_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-mulh-mulhsu-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-mulh-mulhsu-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed MULH lo bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mulhu_no_shared_cpu_bus_semantics_redteam.rs 
b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mulhu_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..ef011cfc --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_mulhu_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,314 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 34 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=34 for packed MULHU (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if 
lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let wide = (lhs as u64) * (rhs as u64); + let lo = (wide & 0xffff_ffff) as u32; + + // Packed-key layout (ell_addr=34): [lhs_u32, rhs_u32, lo_bits[0..32]]. 
+ let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + for bit in 0..32 { + packed[2 + bit] = if ((lo >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_mulhu_semantics_redteam() { + // Program: + // - LUI x1, 16 (x1 = 65536) + // - LUI x2, 16 (x2 = 65536) + // - MULHU x3, x1, x2 (hi=1, lo=0) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 16 }, + RiscvInstruction::Lui { rd: 2, imm: 16 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhu, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let mut exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let mulhu_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Mulhu); + for row in exec.rows.iter_mut() { + if row.active { + row.shout_events.retain(|ev| ev.shout_id == mulhu_id); + } + } + let mulhu_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhu, + .. + }) + ) + }) + .count(); + let mulhu_shout_rows = exec + .rows + .iter() + .filter(|row| { + row.active + && matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhu, + .. 
+ }) + ) + && row.shout_events.iter().any(|ev| ev.shout_id == mulhu_id) + }) + .count(); + assert!(mulhu_rows > 0, "expected at least one MULHU row"); + assert_eq!( + mulhu_shout_rows, mulhu_rows, + "native MULHU shout coverage mismatch (mulhu_rows={mulhu_rows}, shout_rows={mulhu_shout_rows})" + ); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 34usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: MULHU table, 1 lane (tamper one low-bit). 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Mulhu).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let mulhu_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 34, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Mulhu, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut mulhu_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ mulhu_lut_inst.d * mulhu_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("MULHU Shout z"); + + // Flip lo bit 0 on any active lookup row. + let lane0 = &shout_lanes[0]; + let j = lane0 + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one MULHU lookup"); + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 34usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let lo_bit0_col_id = cols + .addr_bits + .clone() + .nth(2) + .expect("expected addr_bits[2] for lo bit 0"); + let cell = bus.bus_cell(lo_bit0_col_id, j); + mulhu_z[cell] = if mulhu_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let mulhu_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &mulhu_z); + let mulhu_c = l.commit(&mulhu_Z); + let mulhu_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![mulhu_c], + ..mulhu_lut_inst + }; + let mulhu_lut_wit = LutWitness { mats: vec![mulhu_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(mulhu_lut_inst, mulhu_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + let mut tr_prove = 
Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-mulhu-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-mulhu-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed MULHU lo bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sll_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sll_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..6d927044 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sll_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,278 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=38 for packed SLL (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let shamt = (rhs_u64 as u32) & 0x1F; + let wide = (lhs as u64) << shamt; + let carry = (wide >> 32) as u32; + + // Packed-key layout (ell_addr=38): + // 
[lhs_u32, shamt_bits[0..5], carry_bits[0..32]]. + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + for bit in 0..5 { + packed[1 + bit] = if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + for bit in 0..32 { + packed[6 + bit] = if ((carry >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_sll_semantics_redteam() { + // Program: + // - LUI x1, 1 (x1 = 4096) + // - SLLI x2, x1, 3 (x2 = 32768) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 1 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sll, + rd: 2, + rs1: 1, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 
2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SLL table, 1 lane (tamper one carry bit). + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sll).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sll_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sll, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut sll_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sll_lut_inst.d * sll_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SLL Shout z"); + + // Flip carry bit 0 on any active lookup row. 
+ let lane0 = &shout_lanes[0]; + let j = lane0 + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one SLL lookup"); + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 38usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let carry_bit0_col_id = cols + .addr_bits + .clone() + .nth(6) + .expect("expected addr_bits[6] for carry bit 0"); + let cell = bus.bus_cell(carry_bit0_col_id, j); + sll_z[cell] = if sll_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let sll_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sll_z); + let sll_c = l.commit(&sll_Z); + let sll_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sll_c], + ..sll_lut_inst + }; + let sll_lut_wit = LutWitness { mats: vec![sll_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sll_lut_inst, sll_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either: + // - reject because the tampered witness no longer satisfies the protocol invariants, or + // - emit a proof that fails verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sll-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sll-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed SLL carry bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_slt_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_slt_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..c35a4aea --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_slt_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,283 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 37 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=37 for packed SLT (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + + let lhs_sign = (lhs >> 31) & 1; + let rhs_sign = (rhs >> 31) & 1; + + let lhs_b = lhs ^ 0x8000_0000; + let rhs_b = rhs ^ 
0x8000_0000; + let diff = lhs_b.wrapping_sub(rhs_b); + + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(diff as u64); + packed[3] = if lhs_sign == 1 { F::ONE } else { F::ZERO }; + packed[4] = if rhs_sign == 1 { F::ONE } else { F::ZERO }; + for bit in 0..32 { + let b = (diff >> bit) & 1; + packed[5 + bit] = if b == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_slt_semantics_redteam() { + // Program: + // - LUI x1, 0x80000 (x1 = 0x8000_0000) + // - SLTI x2, x1, 0 (1) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Slt, + rd: 2, + rs1: 1, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = 
build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 37usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SLT table, 1 lane (tamper one diff bit). + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Slt).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let slt_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 37, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Slt, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut slt_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ slt_lut_inst.d * slt_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SLT Shout z"); + + // Flip the first diff bit on any active lookup row. 
+ let lane0 = &shout_lanes[0]; + let j = lane0 + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one SLT lookup"); + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 37usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let diff_bit0_col_id = cols + .addr_bits + .clone() + .nth(5) + .expect("expected addr_bits[5] for diff bit 0"); + let cell = bus.bus_cell(diff_bit0_col_id, j); + slt_z[cell] = if slt_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let slt_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &slt_z); + let slt_c = l.commit(&slt_Z); + let slt_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![slt_c], + ..slt_lut_inst + }; + let slt_lut_wit = LutWitness { mats: vec![slt_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(slt_lut_inst, slt_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either: + // - reject because the tampered witness no longer satisfies the protocol invariants, or + // - emit a proof that fails verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-slt-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-slt-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed SLT diff bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sltu_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sltu_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..725ad966 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sltu_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,278 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 35 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=35 for packed SLTU (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + // addr_bits: [lhs_u32, rhs_u32, diff_u32, diff_bits[0..32]] + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let rhs = rhs_u64 as u32; + let diff = lhs.wrapping_sub(rhs); + + let mut packed = vec![F::ZERO; 
ell_addr]; + packed[0] = F::from_u64(lhs as u64); + packed[1] = F::from_u64(rhs as u64); + packed[2] = F::from_u64(diff as u64); + for bit in 0..32 { + let b = (diff >> bit) & 1; + packed[3 + bit] = if b == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_no_shared_cpu_bus_shout_sltu_semantics_redteam() { + // Program: + // - LUI x1, 0 (x1 = 0) + // - LUI x2, 1 (x2 = 4096) + // - SLTU x3, x1, x2 (1) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0 }, + RiscvInstruction::Lui { rd: 2, imm: 1 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sltu, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 35usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + 
.expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SLTU table, 1 lane (tamper one diff bit). + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sltu).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sltu_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 35, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sltu, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut sltu_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sltu_lut_inst.d * sltu_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SLTU Shout z"); + + // Flip the first diff bit on any active lookup row. 
+ let lane0 = &shout_lanes[0]; + let j = lane0 + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one SLTU lookup"); + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 35usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let diff_bit0_col_id = cols + .addr_bits + .clone() + .nth(3) + .expect("expected addr_bits[3] for diff bit 0"); + let cell = bus.bus_cell(diff_bit0_col_id, j); + sltu_z[cell] = if sltu_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let sltu_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sltu_z); + let sltu_c = l.commit(&sltu_Z); + let sltu_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sltu_c], + ..sltu_lut_inst + }; + let sltu_lut_wit = LutWitness { mats: vec![sltu_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sltu_lut_inst, sltu_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either: + // - reject because the tampered witness no longer satisfies the protocol invariants, or + // - emit a proof that fails verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sltu-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sltu-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed SLTU diff bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sra_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sra_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..90bcbed8 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sra_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,307 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=38 for packed SRA (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let shamt = (rhs_u64 as u32) & 0x1F; + let sign = (lhs >> 31) & 1; + + let lhs_signed: i64 = if sign == 1 { + (lhs as i64) - (1i64 << 32) + } else { + lhs as 
i64 + }; + let val_u32 = lane.value[j] as u32; + let val_signed: i64 = (val_u32 as i64) - (sign as i64) * (1i64 << 32); + let pow2: i64 = 1i64 << shamt; + let rem_i64 = lhs_signed - val_signed * pow2; + if rem_i64 < 0 { + return Err("build_shout_only_bus_z: negative SRA remainder".into()); + } + let rem = rem_i64 as u64; + + // Packed-key layout (ell_addr=38): + // [lhs_u32, shamt_bits[0..5], sign_bit, rem_bits[0..31]]. + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + for bit in 0..5 { + packed[1 + bit] = if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + packed[6] = if sign == 1 { F::ONE } else { F::ZERO }; + for bit in 0..31 { + packed[7 + bit] = if ((rem >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_sra_semantics_redteam() { + // Program: + // - LUI x1, 0x80000 (x1 = 0x8000_0000) + // - SRAI x2, x1, 3 (x2 = 0xF000_0000) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x80000 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Sra, + rd: 2, + rs1: 1, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + 
exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Shout lane data for SRA (used to coordinate a linkage-preserving tamper). + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sra).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + let lane0 = &shout_lanes[0]; + let j = lane0 + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one SRA lookup"); + + let (_lhs_u64, rhs_u64) = uninterleave_bits(lane0.key[j] as u128); + let shamt = (rhs_u64 as u32) & 0x1F; + assert!(shamt > 0, "redteam requires shamt>0"); + let old_val = lane0.value[j]; + assert!(old_val > 0, "redteam requires val>0"); + let new_val = old_val - 1; + + // Tamper CPU trace shout_val at row j (CCS trace wiring doesn't constrain shout_val). + let val_idx = layout.cell(layout.trace.shout_val, j); + w[val_idx - layout.m_in] = F::from_u64(new_val); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (tampered). 
+ let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SRA table, 1 lane (tamper remainder-bound while preserving value equation + linkage). + let sra_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sra, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut sra_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sra_lut_inst.d * sra_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SRA Shout z"); + + // Set rem_bit[shamt] = 1 on the executed lookup row. + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 38usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let rem_bit_col_id = cols + .addr_bits + .clone() + .nth(7 + shamt as usize) + .expect("expected addr_bits[7+shamt] for rem_bit[shamt]"); + let cell = bus.bus_cell(rem_bit_col_id, j); + sra_z[cell] = F::ONE; + + // Adjust Shout val to preserve the value equation and trace↔Shout linkage. 
+ let val_cell = bus.bus_cell(cols.primary_val(), j); + sra_z[val_cell] = F::from_u64(new_val); + + let sra_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sra_z); + let sra_c = l.commit(&sra_Z); + let sra_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sra_c], + ..sra_lut_inst + }; + let sra_lut_wit = LutWitness { mats: vec![sra_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sra_lut_inst, sra_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either: + // - reject because the tampered witness no longer satisfies the protocol invariants, or + // - emit a proof that fails verification. + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sra-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sra-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed SRA remainder must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_srl_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_srl_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..5347a1e0 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_srl_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,301 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use 
neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer, widen_ccs_cols_for_test}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 38 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=38 for packed SRL (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } 
+ + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = lane_data + .get(lane_idx) + .ok_or_else(|| format!("missing lane_data[{lane_idx}]"))?; + for j in 0..t { + let has_lookup = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has_lookup { F::ONE } else { F::ZERO }; + + if has_lookup { + z[bus.bus_cell(cols.primary_val(), j)] = F::from_u64(lane.value[j]); + } + + if has_lookup { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs = lhs_u64 as u32; + let shamt = (rhs_u64 as u32) & 0x1F; + let rem: u32 = if shamt == 0 { + 0 + } else { + let mask = (1u64 << shamt) - 1; + ((lhs as u64) & mask) as u32 + }; + + // Packed-key layout (ell_addr=38): + // [lhs_u32, shamt_bits[0..5], rem_bits[0..32]]. + let mut packed = vec![F::ZERO; ell_addr]; + packed[0] = F::from_u64(lhs as u64); + for bit in 0..5 { + packed[1 + bit] = if ((shamt >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + for bit in 0..32 { + packed[6 + bit] = if ((rem >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; + } + + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_srl_semantics_redteam() { + // Program: + // - LUI x1, 1 (x1 = 4096) + // - SRLI x2, x1, 3 (x2 = 512) + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 1 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Srl, + rd: 2, + rs1: 1, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = 
RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let mut ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let min_m = layout + .m_in + .checked_add((/*bus_cols=*/ 38usize + 2usize).checked_mul(exec.rows.len()).expect("bus cols * steps")) + .expect("m_in + bus region"); + widen_ccs_cols_for_test(&mut ccs, min_m); + w.resize(ccs.m - layout.m_in, F::ZERO); + + // Shout lane data for SRL (used to coordinate a linkage-preserving tamper). 
+ let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Srl).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + let lane0 = &shout_lanes[0]; + let j = lane0 + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one SRL lookup"); + + // Pick the executed lookup (lhs, shamt) and tamper it so: + // - value equation still holds (lhs = val*2^shamt + rem) + // - trace↔Shout linkage still holds (we tamper both CPU shout_val and Shout val) + // - but the remainder-bound check fails (we set rem_bit[shamt] = 1, so rem >= 2^shamt) + let (_lhs_u64, rhs_u64) = uninterleave_bits(lane0.key[j] as u128); + let shamt = (rhs_u64 as u32) & 0x1F; + assert!(shamt > 0, "redteam requires shamt>0"); + let old_val = lane0.value[j]; + assert!(old_val > 0, "redteam requires val>0"); + let new_val = old_val - 1; + + // Tamper CPU trace shout_val at row j (CCS trace wiring doesn't constrain shout_val). + let val_idx = layout.cell(layout.trace.shout_val, j); + w[val_idx - layout.m_in] = F::from_u64(new_val); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (tampered). + let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SRL table, 1 lane (tamper remainder-bound while preserving value equation + linkage). 
+ let srl_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 38, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Srl, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + + let mut srl_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ srl_lut_inst.d * srl_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SRL Shout z"); + + // Set rem_bit[shamt] = 1 on the executed lookup row. + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 38usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let rem_bit_col_id = cols + .addr_bits + .clone() + .nth(6 + shamt as usize) + .expect("expected addr_bits[6+shamt] for rem_bit[shamt]"); + let cell = bus.bus_cell(rem_bit_col_id, j); + srl_z[cell] = F::ONE; + + // Adjust Shout val to preserve the value equation and trace↔Shout linkage. + let val_cell = bus.bus_cell(cols.primary_val(), j); + srl_z[val_cell] = F::from_u64(new_val); + + let srl_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &srl_z); + let srl_c = l.commit(&srl_Z); + let srl_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![srl_c], + ..srl_lut_inst + }; + let srl_lut_wit = LutWitness { mats: vec![srl_Z] }; + + let steps_witness = vec![StepWitnessBundle { + mcs, + lut_instances: vec![(srl_lut_inst, srl_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either: + // - reject because the tampered witness no longer satisfies the protocol invariants, or + // - emit a proof that fails verification. 
+ let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-srl-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-srl-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed SRL remainder must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sub_no_shared_cpu_bus_semantics_redteam.rs b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sub_no_shared_cpu_bus_semantics_redteam.rs new file mode 100644 index 00000000..c164aa47 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_shout/semantics_redteam/riscv_trace_shout_sub_no_shared_cpu_bus_semantics_redteam.rs @@ -0,0 +1,258 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +use neo_ajtai::Commitment as Cmt; +use neo_ccs::relations::{McsInstance, McsWitness}; +use neo_ccs::traits::SModuleHomomorphism; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::shard::{fold_shard_prove, fold_shard_verify}; +use neo_math::F; +use neo_memory::cpu::build_bus_layout_for_instances_with_shout_and_twist_lanes; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, uninterleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_memory::witness::{LutInstance, LutTableSpec, LutWitness, StepInstanceBundle, StepWitnessBundle}; +use neo_params::NeoParams; +use 
neo_transcript::Poseidon2Transcript; +use neo_transcript::Transcript; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; + +use crate::suite::{default_mixers, setup_ajtai_committer}; + +fn build_shout_only_bus_z( + m: usize, + m_in: usize, + t: usize, + ell_addr: usize, + lanes: usize, + lane_data: &[neo_memory::riscv::trace::ShoutLaneOverTime], + x_prefix: &[F], +) -> Result, String> { + if ell_addr != 3 { + return Err(format!( + "build_shout_only_bus_z: expected ell_addr=3 for packed SUB (got ell_addr={ell_addr})" + )); + } + if x_prefix.len() != m_in { + return Err(format!( + "build_shout_only_bus_z: x_prefix.len()={} != m_in={}", + x_prefix.len(), + m_in + )); + } + if lane_data.len() != lanes { + return Err(format!( + "build_shout_only_bus_z: lane_data.len()={} != lanes={}", + lane_data.len(), + lanes + )); + } + + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + m, + m_in, + t, + core::iter::once((ell_addr, lanes)), + core::iter::empty::<(usize, usize)>(), + )?; + if bus.shout_cols.len() != 1 || !bus.twist_cols.is_empty() { + return Err("build_shout_only_bus_z: expected 1 shout instance and 0 twist instances".into()); + } + + let mut z = vec![F::ZERO; m]; + z[..m_in].copy_from_slice(x_prefix); + + let shout = &bus.shout_cols[0]; + for (lane_idx, cols) in shout.lanes.iter().enumerate() { + let lane = &lane_data[lane_idx]; + if lane.has_lookup.len() != t { + return Err("build_shout_only_bus_z: lane length mismatch".into()); + } + for j in 0..t { + let has = lane.has_lookup[j]; + z[bus.bus_cell(cols.has_lookup, j)] = if has { F::ONE } else { F::ZERO }; + z[bus.bus_cell(cols.primary_val(), j)] = if has { F::from_u64(lane.value[j]) } else { F::ZERO }; + + // Packed-key layout: [lhs_u32, rhs_u32, borrow_bit] + let mut packed = [F::ZERO; 3]; + if has { + let (lhs_u64, rhs_u64) = uninterleave_bits(lane.key[j] as u128); + let lhs_u32 = lhs_u64 as u32; + let rhs_u32 = rhs_u64 as u32; + let borrow = (lhs_u32 < rhs_u32) as 
u64; + packed[0] = F::from_u64(lhs_u32 as u64); + packed[1] = F::from_u64(rhs_u32 as u64); + packed[2] = F::from_u64(borrow); + } + for (idx, col_id) in cols.addr_bits.clone().enumerate() { + z[bus.bus_cell(col_id, j)] = packed[idx]; + } + } + } + + Ok(z) +} + +#[test] +fn riscv_trace_wiring_ccs_no_shared_cpu_bus_shout_sub_semantics_redteam() { + // Program: + // - LUI x1, 0 + // - LUI x2, 1 + // - SUB x3, x1, x2 + // - HALT + let program = vec![ + RiscvInstruction::Lui { rd: 1, imm: 0 }, + RiscvInstruction::Lui { rd: 2, imm: 1 }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sub, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Params + committer. + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n.max(ccs.m)).expect("params"); + params.k_rho = 16; + let l = setup_ajtai_committer(¶ms, ccs.m); + let mixers = default_mixers(); + + // Main CPU trace witness commitment (honest). 
+ let z_cpu: Vec = x.iter().copied().chain(w.iter().copied()).collect(); + let Z_cpu = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &z_cpu); + let c_cpu = l.commit(&Z_cpu); + let mcs = ( + McsInstance { + c: c_cpu, + x: x.clone(), + m_in: layout.m_in, + }, + McsWitness { w, Z: Z_cpu }, + ); + + // Shout instance: SUB table, 1 lane (tamper packed borrow bit). + let t = exec.rows.len(); + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sub).0]; + let shout_lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + + let sub_lut_inst = LutInstance:: { + table_id: 0, + comms: Vec::new(), + k: 0, + d: 3, + n_side: 2, + steps: t, + lanes: 1, + ell: 1, + table_spec: Some(LutTableSpec::RiscvOpcodePacked { + opcode: RiscvOpcode::Sub, + xlen: 32, + }), + table: Vec::new(), + addr_group: None, + selector_group: None, + }; + let mut sub_z = build_shout_only_bus_z( + ccs.m, + layout.m_in, + t, + /*ell_addr=*/ sub_lut_inst.d * sub_lut_inst.ell, + /*lanes=*/ 1, + &shout_lanes, + &x, + ) + .expect("SUB Shout z"); + + let j = shout_lanes[0] + .has_lookup + .iter() + .position(|&b| b) + .expect("expected at least one SUB lookup"); + let bus = build_bus_layout_for_instances_with_shout_and_twist_lanes( + ccs.m, + layout.m_in, + t, + core::iter::once((/*ell_addr=*/ 3usize, /*lanes=*/ 1usize)), + core::iter::empty::<(usize, usize)>(), + ) + .expect("bus layout"); + let cols = &bus.shout_cols[0].lanes[0]; + let borrow_col_id = cols + .addr_bits + .clone() + .nth(2) + .expect("expected addr_bits[2] for borrow bit"); + let cell = bus.bus_cell(borrow_col_id, j); + sub_z[cell] = if sub_z[cell] == F::ONE { F::ZERO } else { F::ONE }; + + let sub_Z = neo_memory::ajtai::encode_vector_balanced_to_mat(¶ms, &sub_z); + let sub_c = l.commit(&sub_Z); + let sub_lut_inst = LutInstance:: { + table_id: 0, + comms: vec![sub_c], + ..sub_lut_inst + }; + let sub_lut_wit = LutWitness { mats: vec![sub_Z] }; + + let steps_witness = 
vec![StepWitnessBundle { + mcs, + lut_instances: vec![(sub_lut_inst, sub_lut_wit)], + mem_instances: Vec::new(), + _phantom: PhantomData, + }]; + let steps_instance: Vec> = + steps_witness.iter().map(StepInstanceBundle::from).collect(); + + // The prover may either reject because witness is invalid, or emit proof that fails verification. + let mut tr_prove = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sub-semantics-redteam"); + if let Ok(proof) = fold_shard_prove( + FoldingMode::Optimized, + &mut tr_prove, + ¶ms, + &ccs, + &steps_witness, + &[], + &[], + &l, + mixers, + ) { + let mut tr_verify = Poseidon2Transcript::new(b"riscv-trace-no-shared-bus-shout-sub-semantics-redteam"); + fold_shard_verify( + FoldingMode::Optimized, + &mut tr_verify, + ¶ms, + &ccs, + &steps_instance, + &[], + &proof, + mixers, + ) + .expect_err("tampered packed SUB borrow bit must be caught by Route-A time constraints"); + } +} diff --git a/crates/neo-fold/tests/shout_identity_u32_range_check.rs b/crates/neo-fold/tests/suites/trace_shout/shout_identity_u32_range_check.rs similarity index 95% rename from crates/neo-fold/tests/shout_identity_u32_range_check.rs rename to crates/neo-fold/tests/suites/trace_shout/shout_identity_u32_range_check.rs index 7620be94..4768b6ab 100644 --- a/crates/neo-fold/tests/shout_identity_u32_range_check.rs +++ b/crates/neo-fold/tests/suites/trace_shout/shout_identity_u32_range_check.rs @@ -88,7 +88,7 @@ fn write_shout_lane_row( #[test] fn route_a_shout_identity_u32_range_check_two_lanes_same_value_verifies() { let ccs = create_identity_ccs(TEST_M); - let mut session = FoldingSession::::new_ajtai_seeded(FoldingMode::PaperExact, &ccs, [7u8; 32]) + let mut session = FoldingSession::::new_ajtai_seeded(FoldingMode::Optimized, &ccs, [7u8; 32]) .expect("new_ajtai_seeded"); let params = session.params().clone(); @@ -96,6 +96,7 @@ fn route_a_shout_identity_u32_range_check_two_lanes_same_value_verifies() { let x: u64 = 0x1234_5678; let inst = LutInstance:: { 
+ table_id: 0, comms: Vec::new(), k: 0, d: 32, @@ -105,6 +106,8 @@ fn route_a_shout_identity_u32_range_check_two_lanes_same_value_verifies() { ell: 1, table_spec: Some(LutTableSpec::IdentityU32), table: vec![], + addr_group: None, + selector_group: None, }; let wit = LutWitness { mats: Vec::new() }; @@ -135,7 +138,7 @@ fn route_a_shout_identity_u32_range_check_two_lanes_same_value_verifies() { #[test] fn route_a_shout_identity_u32_range_check_rejects_wrong_val() { let ccs = create_identity_ccs(TEST_M); - let mut session = FoldingSession::::new_ajtai_seeded(FoldingMode::PaperExact, &ccs, [8u8; 32]) + let mut session = FoldingSession::::new_ajtai_seeded(FoldingMode::Optimized, &ccs, [8u8; 32]) .expect("new_ajtai_seeded"); let params = session.params().clone(); @@ -143,6 +146,7 @@ fn route_a_shout_identity_u32_range_check_rejects_wrong_val() { let bad: u64 = x.wrapping_add(5); let inst = LutInstance:: { + table_id: 0, comms: Vec::new(), k: 0, d: 32, @@ -152,6 +156,8 @@ fn route_a_shout_identity_u32_range_check_rejects_wrong_val() { ell: 1, table_spec: Some(LutTableSpec::IdentityU32), table: vec![], + addr_group: None, + selector_group: None, }; let wit = LutWitness { mats: Vec::new() }; diff --git a/crates/neo-fold/tests/shout_multi_lookup_implicit_table_spec.rs b/crates/neo-fold/tests/suites/trace_shout/shout_multi_lookup_implicit_table_spec.rs similarity index 100% rename from crates/neo-fold/tests/shout_multi_lookup_implicit_table_spec.rs rename to crates/neo-fold/tests/suites/trace_shout/shout_multi_lookup_implicit_table_spec.rs diff --git a/crates/neo-fold/tests/shout_multi_lookup_per_step.rs b/crates/neo-fold/tests/suites/trace_shout/shout_multi_lookup_per_step.rs similarity index 100% rename from crates/neo-fold/tests/shout_multi_lookup_per_step.rs rename to crates/neo-fold/tests/suites/trace_shout/shout_multi_lookup_per_step.rs diff --git a/crates/neo-fold/tests/shout_padded_binary_table.rs 
b/crates/neo-fold/tests/suites/trace_shout/shout_padded_binary_table.rs similarity index 100% rename from crates/neo-fold/tests/shout_padded_binary_table.rs rename to crates/neo-fold/tests/suites/trace_shout/shout_padded_binary_table.rs diff --git a/crates/neo-fold/tests/suites/trace_twist/mod.rs b/crates/neo-fold/tests/suites/trace_twist/mod.rs new file mode 100644 index 00000000..eb57ba54 --- /dev/null +++ b/crates/neo-fold/tests/suites/trace_twist/mod.rs @@ -0,0 +1,5 @@ +mod twist_lane_pinning; +mod twist_multi_write_per_step; +mod twist_shout_fibonacci_cycle_trace; +mod twist_shout_power_tests; +mod twist_shout_soundness; diff --git a/crates/neo-fold/tests/twist_lane_pinning.rs b/crates/neo-fold/tests/suites/trace_twist/twist_lane_pinning.rs similarity index 100% rename from crates/neo-fold/tests/twist_lane_pinning.rs rename to crates/neo-fold/tests/suites/trace_twist/twist_lane_pinning.rs diff --git a/crates/neo-fold/tests/twist_multi_write_per_step.rs b/crates/neo-fold/tests/suites/trace_twist/twist_multi_write_per_step.rs similarity index 100% rename from crates/neo-fold/tests/twist_multi_write_per_step.rs rename to crates/neo-fold/tests/suites/trace_twist/twist_multi_write_per_step.rs diff --git a/crates/neo-fold/tests/twist_shout_fibonacci_cycle_trace.rs b/crates/neo-fold/tests/suites/trace_twist/twist_shout_fibonacci_cycle_trace.rs similarity index 93% rename from crates/neo-fold/tests/twist_shout_fibonacci_cycle_trace.rs rename to crates/neo-fold/tests/suites/trace_twist/twist_shout_fibonacci_cycle_trace.rs index dbbfe284..c11062df 100644 --- a/crates/neo-fold/tests/twist_shout_fibonacci_cycle_trace.rs +++ b/crates/neo-fold/tests/suites/trace_twist/twist_shout_fibonacci_cycle_trace.rs @@ -41,7 +41,7 @@ //! value stored in Twist memory `mem[0]`, and we attach an output-binding proof so the verifier //! checks that claim (not just that the execution is internally consistent). 
-#[path = "common/fib_twist_shout_vm.rs"] +#[path = "../../common/fib_twist_shout_vm.rs"] mod fib_twist_shout_vm; use std::collections::HashMap; @@ -409,8 +409,10 @@ fn twist_shout_fibonacci_cycle_trace() { .unwrap_or(0) ); println!( - "mem_sidecar: cpu_me_claims_val={} proofs={}", - step_proof.mem.cpu_me_claims_val.len(), + "mem_sidecar: val_me_claims={} wb_me_claims={} wp_me_claims={} proofs={}", + step_proof.mem.val_me_claims.len(), + step_proof.mem.wb_me_claims.len(), + step_proof.mem.wp_me_claims.len(), step_proof.mem.proofs.len() ); println!( @@ -472,15 +474,37 @@ fn twist_shout_fibonacci_cycle_trace() { } } - if let Some(val_fold) = &step_proof.val_fold { + if step_proof.val_fold.is_empty() { + println!("val_lane: "); + } else { + let total_children: usize = step_proof + .val_fold + .iter() + .map(|p| p.dec_children.len()) + .sum(); println!( - "val_lane: rlc_rhos={} dec_children={}", - val_fold.rlc_rhos.len(), - val_fold.dec_children.len() + "val_lane: proofs={} total_dec_children={}", + step_proof.val_fold.len(), + total_children ); - } else { - println!("val_lane: "); } + let wb_children: usize = step_proof + .wb_fold + .iter() + .map(|p| p.dec_children.len()) + .sum(); + let wp_children: usize = step_proof + .wp_fold + .iter() + .map(|p| p.dec_children.len()) + .sum(); + println!( + "wb_lane: proofs={} total_dec_children={} | wp_lane: proofs={} total_dec_children={}", + step_proof.wb_fold.len(), + wb_children, + step_proof.wp_fold.len(), + wp_children + ); } } } diff --git a/crates/neo-fold/tests/twist_shout_power_tests.rs b/crates/neo-fold/tests/suites/trace_twist/twist_shout_power_tests.rs similarity index 98% rename from crates/neo-fold/tests/twist_shout_power_tests.rs rename to crates/neo-fold/tests/suites/trace_twist/twist_shout_power_tests.rs index 6bd1a341..9d466f76 100644 --- a/crates/neo-fold/tests/twist_shout_power_tests.rs +++ b/crates/neo-fold/tests/suites/trace_twist/twist_shout_power_tests.rs @@ -1,7 +1,7 @@ #![allow(non_snake_case)] 
#![allow(deprecated)] -#[path = "common/fixtures.rs"] +#[path = "../../common/fixtures.rs"] mod fixtures; use fixtures::{ @@ -108,7 +108,7 @@ fn redteam_drop_val_fold_must_fail() { let fx = build_twist_shout_2step_fixture(3); let mut proof = prove(FoldingMode::Optimized, &fx); - proof.steps[0].val_fold = None; + proof.steps[0].val_fold.clear(); assert!( verify(FoldingMode::Optimized, &fx, &proof).is_err(), diff --git a/crates/neo-fold/tests/twist_shout_soundness.rs b/crates/neo-fold/tests/suites/trace_twist/twist_shout_soundness.rs similarity index 98% rename from crates/neo-fold/tests/twist_shout_soundness.rs rename to crates/neo-fold/tests/suites/trace_twist/twist_shout_soundness.rs index b7f35988..e02be54b 100644 --- a/crates/neo-fold/tests/twist_shout_soundness.rs +++ b/crates/neo-fold/tests/suites/trace_twist/twist_shout_soundness.rs @@ -1,7 +1,7 @@ #![allow(non_snake_case)] #![allow(deprecated)] -#[path = "common/fixtures.rs"] +#[path = "../../common/fixtures.rs"] mod fixtures; use neo_ajtai::{AjtaiSModule, Commitment as Cmt}; @@ -416,7 +416,8 @@ fn tamper_batched_time_static_claim_sum_nonzero_fails() { let dims = utils::build_dims_and_policy(¶ms, &ccs).expect("dims"); let step_inst = StepInstanceBundle::from(&step_bundle); - let metas = RouteATimeClaimPlan::time_claim_metas_for_step(&step_inst, dims.d_sc, None); + let metas = + RouteATimeClaimPlan::time_claim_metas_for_step(&step_inst, dims.d_sc, false, false, false, false, false, None); let static_idx = metas .iter() .enumerate() diff --git a/crates/neo-fold/tests/suites/vm/mod.rs b/crates/neo-fold/tests/suites/vm/mod.rs new file mode 100644 index 00000000..ec27617f --- /dev/null +++ b/crates/neo-fold/tests/suites/vm/mod.rs @@ -0,0 +1,4 @@ +mod riscv_chunk_size_auto; +mod riscv_exec_table_extraction; +mod test_riscv_wasm_demo_memory; +mod vm_opcode_dispatch_tests; diff --git a/crates/neo-fold/tests/suites/vm/riscv_chunk_size_auto.rs b/crates/neo-fold/tests/suites/vm/riscv_chunk_size_auto.rs new file mode 
100644 index 00000000..87f30a45 --- /dev/null +++ b/crates/neo-fold/tests/suites/vm/riscv_chunk_size_auto.rs @@ -0,0 +1,29 @@ +#![allow(non_snake_case)] + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; + +#[test] +fn rv32_trace_chunk_rows_auto_prove_verify() { + // Small halting program. + let program: Vec = (0..9) + .map(|i| RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 1, + imm: (i + 1) as i32, + }) + .chain(std::iter::once(RiscvInstruction::Halt)) + .collect(); + + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .max_steps(program.len()) + .prove() + .expect("prove"); + + run.verify().expect("verify"); + assert!(run.trace_len() > 0); + assert!(run.fold_count() > 0); +} diff --git a/crates/neo-fold/tests/suites/vm/riscv_exec_table_extraction.rs b/crates/neo-fold/tests/suites/vm/riscv_exec_table_extraction.rs new file mode 100644 index 00000000..fe2a8706 --- /dev/null +++ b/crates/neo-fold/tests/suites/vm/riscv_exec_table_extraction.rs @@ -0,0 +1,131 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::exec_table::{ + Rv32MEventTable, Rv32RamEventKind, Rv32RamEventTable, Rv32RegEventKind, Rv32RegEventTable, +}; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use std::collections::HashMap; + +#[test] +fn exec_table_extracts_from_trace_run_and_pads() { + // Program exercises: + // - REG reads (rs1/rs2) on every step + // - ALU op in the middle of the trace + // - RAM store/load + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 3, + }, // x1 = 3 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 4, + }, // x2 = 4 + RiscvInstruction::RAlu { + op: RiscvOpcode::Add, + rd: 3, + rs1: 1, + rs2: 2, + }, // x3 = 7 + RiscvInstruction::Store { + op: 
RiscvMemOp::Sw, + rs1: 0, + rs2: 3, + imm: 0, + }, // mem[0] = x3 + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 4, + rs1: 0, + imm: 0, + }, // x4 = mem[0] + RiscvInstruction::Halt, + ]; + + let program_bytes = encode_program(&program); + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(32) + .min_trace_len(8) + .chunk_rows(4) + .max_steps(program.len()) + .shout_auto_minimal() + .prove() + .expect("prove"); + run.verify().expect("verify"); + + let exec = run.exec_table(); + assert_eq!(exec.rows.len(), 8); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows are empty"); + exec.validate_halted_tail().expect("halted tail"); + + let active = exec.rows.iter().filter(|r| r.active).count(); + assert_eq!(active, program.len()); + + // Padded rows must be inactive and have no fetched/proven events. + for r in exec.rows.iter().skip(active) { + assert!(!r.active); + assert!(r.prog_read.is_none()); + assert!(r.reg_read_lane0.is_none()); + assert!(r.reg_read_lane1.is_none()); + assert!(r.reg_write_lane0.is_none()); + assert!(r.ram_events.is_empty()); + assert!(r.shout_events.is_empty()); + } + + // Validate regfile/RAM semantics against the statement initial memory. + let init_regs: HashMap = HashMap::new(); + let init_ram: HashMap = HashMap::new(); + exec.validate_regfile_semantics(&init_regs) + .expect("regfile semantics"); + exec.validate_ram_semantics(&init_ram) + .expect("ram semantics"); + + // Extract reg/RAM event tables (sparse-over-time representation). 
+ let reg_table = Rv32RegEventTable::from_exec_table(&exec, &init_regs).expect("reg event table"); + assert_eq!(reg_table.rows.len(), 16); // 2 reads per row + 4 writes + assert_eq!( + reg_table + .rows + .iter() + .filter(|r| r.kind == Rv32RegEventKind::ReadLane0) + .count(), + active + ); + assert_eq!( + reg_table + .rows + .iter() + .filter(|r| r.kind == Rv32RegEventKind::ReadLane1) + .count(), + active + ); + assert_eq!( + reg_table + .rows + .iter() + .filter(|r| r.kind == Rv32RegEventKind::WriteLane0) + .count(), + 4 + ); + + let ram_table = Rv32RamEventTable::from_exec_table(&exec, &init_ram).expect("ram event table"); + assert_eq!(ram_table.rows.len(), 2); + assert!(ram_table + .rows + .iter() + .any(|r| { r.kind == Rv32RamEventKind::Write && r.addr == 0 && r.prev_val == 0 && r.next_val == 7 })); + assert!(ram_table + .rows + .iter() + .any(|r| { r.kind == Rv32RamEventKind::Read && r.addr == 0 && r.prev_val == 7 && r.next_val == 7 })); + + // No RV32M ops in this program. + let m = Rv32MEventTable::from_exec_table(&exec).expect("rv32m event table"); + assert_eq!(m.rows.len(), 0); +} diff --git a/crates/neo-fold/tests/riscv_wasm_demo/mini_asm.rs b/crates/neo-fold/tests/suites/vm/riscv_wasm_demo/mini_asm.rs similarity index 100% rename from crates/neo-fold/tests/riscv_wasm_demo/mini_asm.rs rename to crates/neo-fold/tests/suites/vm/riscv_wasm_demo/mini_asm.rs diff --git a/crates/neo-fold/tests/riscv_wasm_demo/mod.rs b/crates/neo-fold/tests/suites/vm/riscv_wasm_demo/mod.rs similarity index 100% rename from crates/neo-fold/tests/riscv_wasm_demo/mod.rs rename to crates/neo-fold/tests/suites/vm/riscv_wasm_demo/mod.rs diff --git a/crates/neo-fold/tests/riscv_wasm_demo/rv32_fibonacci.asm b/crates/neo-fold/tests/suites/vm/riscv_wasm_demo/rv32_fibonacci.asm similarity index 100% rename from crates/neo-fold/tests/riscv_wasm_demo/rv32_fibonacci.asm rename to crates/neo-fold/tests/suites/vm/riscv_wasm_demo/rv32_fibonacci.asm diff --git 
a/crates/neo-fold/tests/test_riscv_wasm_demo_memory.rs b/crates/neo-fold/tests/suites/vm/test_riscv_wasm_demo_memory.rs similarity index 81% rename from crates/neo-fold/tests/test_riscv_wasm_demo_memory.rs rename to crates/neo-fold/tests/suites/vm/test_riscv_wasm_demo_memory.rs index 0ac0ddd0..dccc5891 100644 --- a/crates/neo-fold/tests/test_riscv_wasm_demo_memory.rs +++ b/crates/neo-fold/tests/suites/vm/test_riscv_wasm_demo_memory.rs @@ -1,8 +1,9 @@ #![allow(non_snake_case)] +#[path = "riscv_wasm_demo/mod.rs"] mod riscv_wasm_demo; -use neo_fold::riscv_shard::Rv32B1; +use neo_fold::riscv_trace_shard::Rv32TraceWiring; use neo_math::F; use p3_field::PrimeCharacteristicRing; @@ -30,7 +31,7 @@ fn env_u32(name: &str, default: u32) -> u32 { fn test_rv32_fibonacci_mini_asm_peak_rss() { let n = env_u32("NEO_RV32_N", 5); let ram_bytes = env_usize("NEO_RV32_RAM_BYTES", 2048); - let chunk_size = env_usize("NEO_RV32_CHUNK_SIZE", 128); + let chunk_rows = env_usize("NEO_RV32_CHUNK_SIZE", 128); let max_steps = env_usize("NEO_RV32_MAX_STEPS", 0); let asm = include_str!("riscv_wasm_demo/rv32_fibonacci.asm"); @@ -48,11 +49,10 @@ fn test_rv32_fibonacci_mini_asm_peak_rss() { ); let mut run = { - let mut b = Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) + let mut b = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) .xlen(32) - .ram_bytes(ram_bytes) .ram_init_u32(/*addr=*/ 0x104, n) - .chunk_size(chunk_size) + .chunk_rows(chunk_rows) .shout_auto_minimal() .output(/*output_addr=*/ 0x100, /*expected_output=*/ expected_f); if max_steps > 0 { @@ -81,22 +81,18 @@ fn test_rv32_fibonacci_mini_asm_peak_rss() { .unwrap_or(0.0) ); - let trace_len = run.riscv_trace_len().ok(); + let trace_len = run.trace_len(); println!( - "rv32-fib: n={n} expected={expected} trace_len={:?} folds={} chunk_size={} ram_bytes={}", + "rv32-fib: n={n} expected={expected} trace_len={} folds={} chunk_rows={} ram_bytes={}", trace_len, run.fold_count(), - chunk_size, + chunk_rows, ram_bytes ); 
println!( - "rv32-fib: ccs_constraints={} ccs_variables={} shout_lookups={:?}", + "rv32-fib: ccs_constraints={} ccs_variables={} shout_tables={}", run.ccs_num_constraints(), run.ccs_num_variables(), - run.shout_lookup_count().ok() + run.used_shout_table_ids().len() ); - - assert!(run - .verify_default_output_claim() - .expect("verify output claim")); } diff --git a/crates/neo-fold/tests/vm_opcode_dispatch_tests.rs b/crates/neo-fold/tests/suites/vm/vm_opcode_dispatch_tests.rs similarity index 95% rename from crates/neo-fold/tests/vm_opcode_dispatch_tests.rs rename to crates/neo-fold/tests/suites/vm/vm_opcode_dispatch_tests.rs index 30d5239d..fe13dba9 100644 --- a/crates/neo-fold/tests/vm_opcode_dispatch_tests.rs +++ b/crates/neo-fold/tests/suites/vm/vm_opcode_dispatch_tests.rs @@ -59,16 +59,7 @@ fn setup_ajtai_pp(m: usize, seed: u64) -> AjtaiSModule { } fn default_mixers() -> Mixers { - fn mix_rhos_commits(_rhos: &[Mat], _cs: &[Cmt]) -> Cmt { - Cmt::zeros(D, 1) - } - fn combine_b_pows(_cs: &[Cmt], _b: u32) -> Cmt { - Cmt::zeros(D, 1) - } - CommitMixers { - mix_rhos_commits, - combine_b_pows, - } + crate::common_setup::default_mixers() } fn create_identity_ccs(n: usize) -> CcsStructure { @@ -268,6 +259,7 @@ fn empty_mem_trace() -> PlainMemTrace { } fn metadata_only_mem_instance( + mem_id: u32, layout: &PlainMemLayout, init: MemInit, steps: usize, @@ -275,6 +267,7 @@ fn metadata_only_mem_instance( let ell = layout.n_side.trailing_zeros() as usize; ( MemInstance { + mem_id, comms: Vec::new(), k: layout.k, d: layout.d, @@ -292,6 +285,7 @@ fn metadata_only_lut_instance(table: &LutTable, steps: usize) -> (LutInstance let ell = table.n_side.trailing_zeros() as usize; ( LutInstance { + table_id: table.table_id, comms: Vec::new(), k: table.k, d: table.d, @@ -301,6 +295,8 @@ fn metadata_only_lut_instance(table: &LutTable, steps: usize) -> (LutInstance ell, table_spec: None, table: table.content.clone(), + addr_group: None, + selector_group: None, }, LutWitness { mats: 
Vec::new() }, ) @@ -362,7 +358,7 @@ fn vm_simple_add_program() { lanes: 1, }; let mem_trace = empty_mem_trace(); - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, MemInit::Zero, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, MemInit::Zero, mem_trace.steps); let (opcode_inst, opcode_wit) = metadata_only_lut_instance(&bytecode_table, opcode_trace.has_lookup.len()); let (imm_inst, imm_wit) = metadata_only_lut_instance(&imm_table, imm_trace.has_lookup.len()); @@ -385,7 +381,7 @@ fn vm_simple_add_program() { let mut tr_prove = Poseidon2Transcript::new(b"vm-add-program"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -400,7 +396,7 @@ fn vm_simple_add_program() { let mut tr_verify = Poseidon2Transcript::new(b"vm-add-program"); let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -449,7 +445,7 @@ fn vm_register_file_operations() { write_val: vec![F::from_u64(10)], inc_at_write_addr: vec![F::from_u64(10)], }; - let (reg_inst, reg_wit) = metadata_only_mem_instance(®_layout, MemInit::Zero, reg_trace.steps); + let (reg_inst, reg_wit) = metadata_only_mem_instance(0, ®_layout, MemInit::Zero, reg_trace.steps); let mem_bus = [(®_inst, ®_trace)]; let (mcs, mcs_wit) = create_mcs_with_bus(¶ms, &ccs, &l, 0, &[], &mem_bus); @@ -476,7 +472,7 @@ fn vm_register_file_operations() { // State after step 0: R0=10 let reg_init = MemInit::Sparse(vec![(0, F::from_u64(10))]); - let (reg_inst, reg_wit) = metadata_only_mem_instance(®_layout, reg_init, reg_trace.steps); + let (reg_inst, reg_wit) = metadata_only_mem_instance(0, ®_layout, reg_init, reg_trace.steps); let mem_bus = [(®_inst, ®_trace)]; let (mcs, mcs_wit) = create_mcs_with_bus(¶ms, &ccs, &l, 1, &[], &mem_bus); @@ -504,7 +500,7 @@ fn vm_register_file_operations() { // State after 
step 1: R0=10, R1=20 let reg_init = MemInit::Sparse(vec![(0, F::from_u64(10)), (1, F::from_u64(20))]); - let (reg_inst, reg_wit) = metadata_only_mem_instance(®_layout, reg_init, reg_trace.steps); + let (reg_inst, reg_wit) = metadata_only_mem_instance(0, ®_layout, reg_init, reg_trace.steps); let mem_bus = [(®_inst, ®_trace)]; let (mcs, mcs_wit) = create_mcs_with_bus(¶ms, &ccs, &l, 2, &[], &mem_bus); @@ -521,7 +517,7 @@ fn vm_register_file_operations() { let mut tr_prove = Poseidon2Transcript::new(b"vm-register-file"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -536,7 +532,7 @@ fn vm_register_file_operations() { let mut tr_verify = Poseidon2Transcript::new(b"vm-register-file"); let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -603,7 +599,7 @@ fn vm_combined_bytecode_and_data_memory() { inc_at_write_addr: vec![F::from_u64(42)], }; let (bytecode_inst, bytecode_wit) = metadata_only_lut_instance(&bytecode, bytecode_trace.has_lookup.len()); - let (ram_inst, ram_wit) = metadata_only_mem_instance(&ram_layout, MemInit::Zero, ram_trace.steps); + let (ram_inst, ram_wit) = metadata_only_mem_instance(0, &ram_layout, MemInit::Zero, ram_trace.steps); let lut_bus = [(&bytecode_inst, &bytecode_trace)]; let mem_bus = [(&ram_inst, &ram_trace)]; @@ -621,7 +617,7 @@ fn vm_combined_bytecode_and_data_memory() { let mut tr_prove = Poseidon2Transcript::new(b"vm-combined-rom-ram"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -636,7 +632,7 @@ fn vm_combined_bytecode_and_data_memory() { let mut tr_verify = Poseidon2Transcript::new(b"vm-combined-rom-ram"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, 
@@ -688,7 +684,7 @@ fn vm_invalid_opcode_claim_fails() { let mut tr_prove = Poseidon2Transcript::new(b"vm-invalid-opcode-claim"); let proof_result = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -704,7 +700,7 @@ fn vm_invalid_opcode_claim_fails() { let steps_public = [StepInstanceBundle::from(&step_bundle)]; assert!( fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, @@ -763,7 +759,7 @@ fn vm_multi_instruction_sequence() { }; let mem_trace = empty_mem_trace(); - let (mem_inst, mem_wit) = metadata_only_mem_instance(&mem_layout, MemInit::Zero, mem_trace.steps); + let (mem_inst, mem_wit) = metadata_only_mem_instance(0, &mem_layout, MemInit::Zero, mem_trace.steps); let (bytecode_inst, bytecode_wit) = metadata_only_lut_instance(&bytecode, bytecode_trace.has_lookup.len()); let lut_bus = [(&bytecode_inst, &bytecode_trace)]; @@ -783,7 +779,7 @@ fn vm_multi_instruction_sequence() { let mut tr_prove = Poseidon2Transcript::new(b"vm-multi-instr"); let proof = fold_shard_prove( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_prove, ¶ms, &ccs, @@ -798,7 +794,7 @@ fn vm_multi_instruction_sequence() { let mut tr_verify = Poseidon2Transcript::new(b"vm-multi-instr"); let steps_public: Vec> = steps.iter().map(StepInstanceBundle::from).collect(); let _ = fold_shard_verify( - FoldingMode::PaperExact, + FoldingMode::Optimized, &mut tr_verify, ¶ms, &ccs, diff --git a/crates/neo-fold/tests/trace_shout.rs b/crates/neo-fold/tests/trace_shout.rs new file mode 100644 index 00000000..70d32732 --- /dev/null +++ b/crates/neo-fold/tests/trace_shout.rs @@ -0,0 +1,5 @@ +#[path = "common/setup.rs"] +mod common_setup; + +#[path = "suites/trace_shout/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/trace_twist.rs b/crates/neo-fold/tests/trace_twist.rs new file mode 100644 index 00000000..021251c0 --- /dev/null +++ b/crates/neo-fold/tests/trace_twist.rs @@ -0,0 +1,5 @@ +#[path = 
"common/setup.rs"] +mod common_setup; + +#[path = "suites/trace_twist/mod.rs"] +mod suite; diff --git a/crates/neo-fold/tests/vm.rs b/crates/neo-fold/tests/vm.rs new file mode 100644 index 00000000..64f35d89 --- /dev/null +++ b/crates/neo-fold/tests/vm.rs @@ -0,0 +1,5 @@ +#[path = "common/setup.rs"] +mod common_setup; + +#[path = "suites/vm/mod.rs"] +mod suite; diff --git a/crates/neo-memory/Cargo.toml b/crates/neo-memory/Cargo.toml index 9b407083..08365ec6 100644 --- a/crates/neo-memory/Cargo.toml +++ b/crates/neo-memory/Cargo.toml @@ -24,5 +24,9 @@ p3-field = { workspace = true } p3-matrix = { workspace = true } p3-goldilocks = { workspace = true } +[dev-dependencies] +rand = { workspace = true } +rand_chacha = { workspace = true } + [lints] workspace = true diff --git a/crates/neo-memory/src/addr.rs b/crates/neo-memory/src/addr.rs index 3fe651b7..1183962d 100644 --- a/crates/neo-memory/src/addr.rs +++ b/crates/neo-memory/src/addr.rs @@ -108,6 +108,35 @@ pub fn validate_pow2_bit_addressing_shape(proto: &'static str, n_side: usize, el pub fn validate_shout_bit_addressing(inst: &LutInstance) -> Result<(), PiCcsError> { // Virtual/implicit tables may not have a materialized `k = n_side^d` table. 
if let Some(spec) = &inst.table_spec { + let rv32_packed_expected_d = |opcode: crate::riscv::lookups::RiscvOpcode| -> Result { + Ok(match opcode { + crate::riscv::lookups::RiscvOpcode::And + | crate::riscv::lookups::RiscvOpcode::Andn + | crate::riscv::lookups::RiscvOpcode::Xor + | crate::riscv::lookups::RiscvOpcode::Or => 34usize, + crate::riscv::lookups::RiscvOpcode::Add | crate::riscv::lookups::RiscvOpcode::Sub => 3usize, + crate::riscv::lookups::RiscvOpcode::Eq | crate::riscv::lookups::RiscvOpcode::Neq => 35usize, + crate::riscv::lookups::RiscvOpcode::Slt => 37usize, + crate::riscv::lookups::RiscvOpcode::Sll => 38usize, + crate::riscv::lookups::RiscvOpcode::Srl => 38usize, + crate::riscv::lookups::RiscvOpcode::Sra => 38usize, + crate::riscv::lookups::RiscvOpcode::Sltu => 35usize, + crate::riscv::lookups::RiscvOpcode::Mul => 34usize, + crate::riscv::lookups::RiscvOpcode::Mulh => 38usize, + crate::riscv::lookups::RiscvOpcode::Mulhu => 34usize, + crate::riscv::lookups::RiscvOpcode::Mulhsu => 37usize, + crate::riscv::lookups::RiscvOpcode::Div => 43usize, + crate::riscv::lookups::RiscvOpcode::Divu => 38usize, + crate::riscv::lookups::RiscvOpcode::Rem => 43usize, + crate::riscv::lookups::RiscvOpcode::Remu => 38usize, + _ => { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V packed): unsupported opcode={opcode:?}" + ))); + } + }) + }; + validate_pow2_bit_addressing_shape("Shout", inst.n_side, inst.ell)?; if inst.k != 0 { return Err(PiCcsError::InvalidInput( @@ -139,6 +168,68 @@ pub fn validate_shout_bit_addressing(inst: &LutInstance) -> Resu ))); } } + LutTableSpec::RiscvOpcodePacked { opcode, xlen } => { + if *xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V packed): expected xlen=32, got xlen={xlen}" + ))); + } + let expected_d = rv32_packed_expected_d(*opcode)?; + // Packed-key Shout lanes are not bit-addressed: we repurpose the addr-bit slice as + // `[lhs_u32, rhs_u32, aux...]` and keep `[has_lookup, val_u32]`. 
+ if inst.n_side != 2 || inst.ell != 1 { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V packed): expected n_side=2, ell=1, got n_side={}, ell={}", + inst.n_side, inst.ell + ))); + } + if inst.d != expected_d { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V packed): expected d={expected_d}, got d={}", + inst.d, + ))); + } + } + LutTableSpec::RiscvOpcodeEventTablePacked { + opcode, + xlen, + time_bits, + } => { + if *xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V event-table packed): expected xlen=32, got xlen={xlen}" + ))); + } + if *time_bits == 0 { + return Err(PiCcsError::InvalidInput( + "Shout(RISC-V event-table packed): time_bits must be >= 1".into(), + )); + } + if *time_bits > 64 { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V event-table packed): time_bits={time_bits} too large (max 64)" + ))); + } + let base_d = rv32_packed_expected_d(*opcode)?; + let expected_d = time_bits + .checked_add(base_d) + .ok_or_else(|| PiCcsError::InvalidInput("Shout(RISC-V event-table packed): d overflow".into()))?; + + // Event-table packed Shout lanes are not bit-addressed: addr_bits is repurposed as + // `[time_bits_le, lhs_u32, rhs_u32, aux...]` and we keep `[has_lookup, val_u32]`. 
+ if inst.n_side != 2 || inst.ell != 1 { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V event-table packed): expected n_side=2, ell=1, got n_side={}, ell={}", + inst.n_side, inst.ell + ))); + } + if inst.d != expected_d { + return Err(PiCcsError::InvalidInput(format!( + "Shout(RISC-V event-table packed): expected d={expected_d} (= time_bits({time_bits}) + base_d({base_d})), got d={}", + inst.d, + ))); + } + } LutTableSpec::IdentityU32 => { if inst.n_side != 2 || inst.ell != 1 { return Err(PiCcsError::InvalidInput(format!( diff --git a/crates/neo-memory/src/builder.rs b/crates/neo-memory/src/builder.rs index ed35c991..4898df9d 100644 --- a/crates/neo-memory/src/builder.rs +++ b/crates/neo-memory/src/builder.rs @@ -26,6 +26,24 @@ pub trait CpuArithmetization { ) -> Result, McsWitness)>, Self::Error> { self.build_ccs_chunks(trace, 1) } + + /// Per-table address-sharing group ids for bus layout column sharing. + /// + /// Tables with the same group id share `addr_bits` columns in the bus layout. + /// Default: empty (no sharing). Override in trace mode for column efficiency. + fn shout_addr_groups(&self) -> &HashMap { + static EMPTY: std::sync::LazyLock> = std::sync::LazyLock::new(HashMap::new); + &EMPTY + } + + /// Per-table selector-sharing group ids for bus layout column sharing. + /// + /// Tables with the same group id share `has_lookup` columns in the bus layout. + /// Default: empty (no sharing). Override in trace mode for column efficiency. + fn shout_selector_groups(&self) -> &HashMap { + static EMPTY: std::sync::LazyLock> = std::sync::LazyLock::new(HashMap::new); + &EMPTY + } } #[derive(Debug)] @@ -106,12 +124,42 @@ where Ok(bundles) } -/// Like `build_shard_witness_shared_cpu_bus`, but also returns auxiliary outputs useful for -/// higher-level APIs (e.g. output binding that needs the terminal Twist memory state). 
-pub fn build_shard_witness_shared_cpu_bus_with_aux( - vm: V, - twist: Tw, - shout: Sh, +/// Build shard witness bundles for **shared CPU bus** mode from an already-executed VM trace. +/// +/// This is equivalent to `build_shard_witness_shared_cpu_bus(...)`, but avoids re-running the VM +/// when the caller already has a `VmTrace` available. +pub fn build_shard_witness_shared_cpu_bus_from_trace( + trace: &VmTrace, + max_steps: usize, + chunk_size: usize, + mem_layouts: &HashMap, + lut_tables: &HashMap>, + lut_table_specs: &HashMap, + lut_lanes: &HashMap, + initial_mem: &HashMap<(u32, u64), Goldilocks>, + cpu_arith: &A, +) -> Result>, ShardBuildError> +where + A: CpuArithmetization, +{ + let (bundles, _aux) = build_shard_witness_shared_cpu_bus_from_trace_with_aux( + trace, + max_steps, + chunk_size, + mem_layouts, + lut_tables, + lut_table_specs, + lut_lanes, + initial_mem, + cpu_arith, + )?; + Ok(bundles) +} + +/// Like `build_shard_witness_shared_cpu_bus_from_trace`, but also returns auxiliary outputs useful +/// for higher-level APIs (e.g. output binding that needs terminal Twist memory states). +pub fn build_shard_witness_shared_cpu_bus_from_trace_with_aux( + trace: &VmTrace, max_steps: usize, chunk_size: usize, mem_layouts: &HashMap, @@ -122,32 +170,18 @@ pub fn build_shard_witness_shared_cpu_bus_with_aux( cpu_arith: &A, ) -> Result<(Vec>, ShardWitnessAux), ShardBuildError> where - V: neo_vm_trace::VmCpu, - Tw: neo_vm_trace::Twist, - Sh: neo_vm_trace::Shout, A: CpuArithmetization, { if chunk_size == 0 { return Err(ShardBuildError::InvalidChunkSize("chunk_size must be >= 1".into())); } - - // 1) Run VM and collect the executed trace for this shard (up to `max_steps`). - // - // NOTE: We intentionally do **not** pad out to `max_steps` here. Padding is handled at the - // per-chunk level by the CPU arithmetization via `is_active`, so the last chunk may be partial. 
- // - // This keeps the proof size proportional to the executed trace length instead of the caller's - // safety bound. - let trace = neo_vm_trace::trace_program(vm, twist, shout, max_steps) - .map_err(|e| ShardBuildError::VmError(e.to_string()))?; - let original_len = trace.steps.len(); - let did_halt = trace.did_halt(); - debug_assert!( - original_len <= max_steps, - "trace_program must not exceed max_steps (got {}, max_steps={})", - original_len, - max_steps - ); + if trace.steps.len() > max_steps { + return Err(ShardBuildError::InvalidChunkSize(format!( + "trace length {} exceeds max_steps {}", + trace.steps.len(), + max_steps + ))); + } // Shared-bus mode does not support "silent dropping" of trace events: if the trace contains // Twist/Shout events, the corresponding instance metadata must be provided so the prover @@ -178,6 +212,8 @@ where } } + let original_len = trace.steps.len(); + let did_halt = trace.did_halt(); let steps_len = trace.steps.len(); let chunks_len = steps_len.div_ceil(chunk_size); @@ -194,7 +230,7 @@ where // 3) CPU arithmetization chunks. let mcss = cpu_arith - .build_ccs_chunks(&trace, chunk_size) + .build_ccs_chunks(trace, chunk_size) .map_err(|e| ShardBuildError::CcsError(e.to_string()))?; if mcss.len() != chunks_len { return Err(ShardBuildError::CcsError(format!( @@ -265,6 +301,7 @@ where let ell = ell_from_pow2_n_side(layout.n_side)?; let inst = MemInstance:: { + mem_id, comms: Vec::new(), k: layout.k, d: layout.d, @@ -325,6 +362,16 @@ where let ell = 1usize; (0usize, d, n_side, ell, Vec::new()) } + LutTableSpec::RiscvOpcodePacked { .. } => { + return Err(ShardBuildError::InvalidInit( + "RiscvOpcodePacked is not supported in the chunked builder path".into(), + )); + } + LutTableSpec::RiscvOpcodeEventTablePacked { .. 
} => { + return Err(ShardBuildError::InvalidInit( + "RiscvOpcodeEventTablePacked is not supported in the chunked builder path".into(), + )); + } LutTableSpec::IdentityU32 => (0usize, 32usize, 2usize, 1usize, Vec::new()), } } else { @@ -337,6 +384,7 @@ where let lanes = lut_lanes.get(&table_id).copied().unwrap_or(1).max(1); let inst = LutInstance:: { + table_id, comms: Vec::new(), k, d, @@ -346,6 +394,8 @@ where ell, table_spec, table, + addr_group: cpu_arith.shout_addr_groups().get(&table_id).copied(), + selector_group: cpu_arith.shout_selector_groups().get(&table_id).copied(), }; let wit = LutWitness { mats: Vec::new() }; lut_instances.push((inst, wit)); @@ -371,3 +421,50 @@ where }; Ok((step_bundles, aux)) } + +/// Like `build_shard_witness_shared_cpu_bus`, but also returns auxiliary outputs useful for +/// higher-level APIs (e.g. output binding that needs the terminal Twist memory state). +pub fn build_shard_witness_shared_cpu_bus_with_aux( + vm: V, + twist: Tw, + shout: Sh, + max_steps: usize, + chunk_size: usize, + mem_layouts: &HashMap, + lut_tables: &HashMap>, + lut_table_specs: &HashMap, + lut_lanes: &HashMap, + initial_mem: &HashMap<(u32, u64), Goldilocks>, + cpu_arith: &A, +) -> Result<(Vec>, ShardWitnessAux), ShardBuildError> +where + V: neo_vm_trace::VmCpu, + Tw: neo_vm_trace::Twist, + Sh: neo_vm_trace::Shout, + A: CpuArithmetization, +{ + if chunk_size == 0 { + return Err(ShardBuildError::InvalidChunkSize("chunk_size must be >= 1".into())); + } + + // 1) Run VM and collect the executed trace for this shard (up to `max_steps`). + // + // NOTE: We intentionally do **not** pad out to `max_steps` here. Padding is handled at the + // per-chunk level by the CPU arithmetization via `is_active`, so the last chunk may be partial. + // + // This keeps the proof size proportional to the executed trace length instead of the caller's + // safety bound. 
+ let trace = neo_vm_trace::trace_program(vm, twist, shout, max_steps) + .map_err(|e| ShardBuildError::VmError(e.to_string()))?; + build_shard_witness_shared_cpu_bus_from_trace_with_aux( + &trace, + max_steps, + chunk_size, + mem_layouts, + lut_tables, + lut_table_specs, + lut_lanes, + initial_mem, + cpu_arith, + ) +} diff --git a/crates/neo-memory/src/cpu/bus_layout.rs b/crates/neo-memory/src/cpu/bus_layout.rs index e9aaea49..48fec635 100644 --- a/crates/neo-memory/src/cpu/bus_layout.rs +++ b/crates/neo-memory/src/cpu/bus_layout.rs @@ -1,4 +1,5 @@ use core::ops::Range; +use std::collections::HashMap; /// Canonical layout for the shared CPU bus tail inside the CPU witness `z`. /// @@ -16,7 +17,7 @@ use core::ops::Range; /// - lanes in order, each lane is: /// - `addr_bits[0..ell_addr)` /// - `has_lookup` -/// - `val` +/// - `vals[0..n_vals)` /// /// Within each Twist instance: /// - `ra_bits[0..ell_addr)` @@ -41,7 +42,15 @@ pub struct BusLayout { pub struct ShoutCols { pub addr_bits: Range, pub has_lookup: usize, - pub val: usize, + pub vals: Vec, +} + +impl ShoutCols { + #[inline] + pub fn primary_val(&self) -> usize { + debug_assert!(!self.vals.is_empty(), "ShoutCols must have at least one value column"); + self.vals[0] + } } #[derive(Clone, Debug)] @@ -49,10 +58,29 @@ pub struct ShoutInstanceCols { /// Lookup lanes for this Shout instance. /// /// Each lane has the canonical per-step bus slice: - /// `[addr_bits, has_lookup, val]`. + /// `[addr_bits, has_lookup, vals[0..n_vals)]`. pub lanes: Vec, } +#[derive(Clone, Copy, Debug)] +pub struct ShoutInstanceShape { + pub ell_addr: usize, + pub lanes: usize, + /// Number of value columns per lane. + pub n_vals: usize, + /// Optional address-sharing group id. + /// + /// Instances with the same `addr_group` reuse the same `addr_bits` columns + /// per lane and allocate only fresh `[has_lookup, val]` columns. + pub addr_group: Option, + /// Optional selector-sharing group id. 
+ /// + /// Instances with the same `selector_group` reuse `has_lookup` columns per lane and + /// allocate only fresh `val` columns. This is safe only for families that are known + /// to have identical `has_lookup` patterns over time. + pub selector_group: Option, +} + #[derive(Clone, Debug)] pub struct TwistCols { pub ra_bits: Range, @@ -151,34 +179,102 @@ pub fn build_bus_layout_for_instances_with_shout_and_twist_lanes( chunk_size: usize, shout_ell_addrs_and_lanes: impl IntoIterator, twist_ell_addrs_and_lanes: impl IntoIterator, +) -> Result { + build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes( + m, + m_in, + chunk_size, + shout_ell_addrs_and_lanes + .into_iter() + .map(|(ell_addr, lanes)| ShoutInstanceShape { + ell_addr, + lanes, + n_vals: 1, + addr_group: None, + selector_group: None, + }), + twist_ell_addrs_and_lanes, + ) +} + +pub fn build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes( + m: usize, + m_in: usize, + chunk_size: usize, + shout_shapes: impl IntoIterator, + twist_ell_addrs_and_lanes: impl IntoIterator, ) -> Result { if chunk_size == 0 { return Err("BusLayout: chunk_size must be >= 1".into()); } let mut col = 0usize; + let mut shared_addr_bits = HashMap::<(u64, usize), (usize, Range)>::new(); + let mut shared_selectors = HashMap::<(u64, usize), usize>::new(); let mut shout_cols = Vec::::new(); - for (ell_addr, lanes) in shout_ell_addrs_and_lanes { + for shape in shout_shapes { + let ell_addr = shape.ell_addr; + let lanes = shape.lanes; + let n_vals = shape.n_vals.max(1); let lanes = lanes.max(1); let mut lane_cols = Vec::::with_capacity(lanes); - for _lane in 0..lanes { - let addr_bits = col..(col + ell_addr); - col = col - .checked_add(ell_addr) - .ok_or_else(|| "BusLayout: column overflow (shout addr_bits)".to_string())?; - let has_lookup = col; - col = col - .checked_add(1) - .ok_or_else(|| "BusLayout: column overflow (shout has_lookup)".to_string())?; - let val = col; - col = col - .checked_add(1) - .ok_or_else(|| 
"BusLayout: column overflow (shout val)".to_string())?; + for lane_idx in 0..lanes { + let addr_bits = if let Some(group_id) = shape.addr_group { + let key = (group_id, lane_idx); + if let Some((prev_ell, prev_range)) = shared_addr_bits.get(&key) { + if *prev_ell != ell_addr { + return Err(format!( + "BusLayout: shared shout addr group mismatch for group_id={group_id}, lane={lane_idx} (prev ell_addr={}, new ell_addr={ell_addr})", + *prev_ell + )); + } + prev_range.clone() + } else { + let range = col..(col + ell_addr); + col = col + .checked_add(ell_addr) + .ok_or_else(|| "BusLayout: column overflow (shout shared addr_bits)".to_string())?; + shared_addr_bits.insert(key, (ell_addr, range.clone())); + range + } + } else { + let range = col..(col + ell_addr); + col = col + .checked_add(ell_addr) + .ok_or_else(|| "BusLayout: column overflow (shout addr_bits)".to_string())?; + range + }; + let has_lookup = if let Some(group_id) = shape.selector_group { + let key = (group_id, lane_idx); + if let Some(prev) = shared_selectors.get(&key) { + *prev + } else { + let out = col; + col = col + .checked_add(1) + .ok_or_else(|| "BusLayout: column overflow (shout has_lookup)".to_string())?; + shared_selectors.insert(key, out); + out + } + } else { + let out = col; + col = col + .checked_add(1) + .ok_or_else(|| "BusLayout: column overflow (shout has_lookup)".to_string())?; + out + }; + let mut vals = Vec::with_capacity(n_vals); + for _ in 0..n_vals { + vals.push(col); + col = col + .checked_add(1) + .ok_or_else(|| "BusLayout: column overflow (shout val)".to_string())?; + } lane_cols.push(ShoutCols { addr_bits, has_lookup, - val, + vals, }); } shout_cols.push(ShoutInstanceCols { lanes: lane_cols }); diff --git a/crates/neo-memory/src/cpu/constraints.rs b/crates/neo-memory/src/cpu/constraints.rs index 5d233637..73ce1c3f 100644 --- a/crates/neo-memory/src/cpu/constraints.rs +++ b/crates/neo-memory/src/cpu/constraints.rs @@ -41,7 +41,8 @@ use neo_ccs::{CcsMatrix, CscMat}; use 
p3_field::{Field, PrimeCharacteristicRing}; use crate::cpu::bus_layout::{ - build_bus_layout_for_instances_with_shout_and_twist_lanes, BusLayout, ShoutCols, TwistCols, + build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes, BusLayout, ShoutCols, ShoutInstanceShape, + TwistCols, }; use crate::witness::{LutInstance, MemInstance}; @@ -96,10 +97,20 @@ pub struct TwistCpuBinding { pub inc: Option, } +/// Sentinel column id used to disable a Twist CPU linkage field on a lane. +/// +/// This is only valid for `TwistCpuBinding` fields and is interpreted by the constraint builder +/// as "no CPU binding for this selector/value/address family". In that mode the lane is still +/// protected by canonical bus padding/bitness constraints. +pub const CPU_BUS_COL_DISABLED: usize = usize::MAX; + /// Per-instance CPU→bus binding for a Shout (lookup) bus slice. #[derive(Clone, Debug)] pub struct ShoutCpuBinding { /// CPU selector column for a lookup op (must equal bus `has_lookup`). + /// + /// `CPU_BUS_COL_DISABLED` disables selector/value linkage for this lane while still allowing + /// optional key binding through `addr` (conditioned by bus `has_lookup`). pub has_lookup: usize, /// Optional packed integer lookup key/address column. /// @@ -111,6 +122,8 @@ pub struct ShoutCpuBinding { /// set this to `None`. pub addr: Option, /// CPU lookup output/value column (must equal bus `val` when `has_lookup=1`). + /// + /// `CPU_BUS_COL_DISABLED` disables value linkage for this lane. pub val: usize, } @@ -392,62 +405,82 @@ impl CpuConstraintBuilder { let bus_inc = layout.bus_cell(twist.inc, j); // CPU columns are assumed to be chunked (contiguous, per-step): col(j) = col_base + j. 
- let cpu_has_read = cpu.has_read + j; - let cpu_has_write = cpu.has_write + j; - let cpu_read_addr = cpu.read_addr + j; - let cpu_write_addr = cpu.write_addr + j; - let cpu_rv = cpu.rv + j; - let cpu_wv = cpu.wv + j; - let cpu_inc = cpu.inc.map(|col| col + j); + let cpu_has_read = (cpu.has_read != CPU_BUS_COL_DISABLED).then_some(cpu.has_read + j); + let cpu_has_write = (cpu.has_write != CPU_BUS_COL_DISABLED).then_some(cpu.has_write + j); + let cpu_read_addr = (cpu.read_addr != CPU_BUS_COL_DISABLED).then_some(cpu.read_addr + j); + let cpu_write_addr = (cpu.write_addr != CPU_BUS_COL_DISABLED).then_some(cpu.write_addr + j); + let cpu_rv = (cpu.rv != CPU_BUS_COL_DISABLED).then_some(cpu.rv + j); + let cpu_wv = (cpu.wv != CPU_BUS_COL_DISABLED).then_some(cpu.wv + j); + let cpu_inc = cpu + .inc + .and_then(|col| (col != CPU_BUS_COL_DISABLED).then_some(col + j)); // Ensure bus selectors are boolean so gated-bit constraints imply true {0,1} bitness. self.add_boolean_constraint(CpuConstraintLabel::TwistHasReadBoolean, bus_has_read); self.add_boolean_constraint(CpuConstraintLabel::TwistHasWriteBoolean, bus_has_write); - // Value binding constraints - // has_read * (rv_cpu - bus_rv) = 0 - self.constraints.push(CpuConstraint::new_eq( - CpuConstraintLabel::LoadValueBinding, - cpu_has_read, - cpu_rv, - bus_rv, - )); - - // has_write * (wv_cpu - bus_wv) = 0 - self.constraints.push(CpuConstraint::new_eq( - CpuConstraintLabel::StoreValueBinding, - cpu_has_write, - cpu_wv, - bus_wv, - )); - - // Selector binding: cpu_has_* == bus_has_* - self.add_equality_constraint(CpuConstraintLabel::LoadSelectorBinding, cpu_has_read, bus_has_read); - self.add_equality_constraint(CpuConstraintLabel::StoreSelectorBinding, cpu_has_write, bus_has_write); - - // Address binding (bit-pack): - // - has_read * (read_addr - pack(ra_bits)) = 0 - // - has_write * (write_addr - pack(wa_bits)) = 0 - self.constraints.push(CpuConstraint::new_terms( - CpuConstraintLabel::LoadAddressBinding, - cpu_has_read, - 
false, - pack_addr_bits::(cpu_read_addr, twist.ra_bits.clone(), layout, j), - )); - self.constraints.push(CpuConstraint::new_terms( - CpuConstraintLabel::StoreAddressBinding, - cpu_has_write, - false, - pack_addr_bits::(cpu_write_addr, twist.wa_bits.clone(), layout, j), - )); + if let Some(cpu_has_read) = cpu_has_read { + let cpu_rv = cpu_rv.expect("Twist read binding requires rv column"); + let cpu_read_addr = cpu_read_addr.expect("Twist read binding requires read_addr column"); + // has_read * (rv_cpu - bus_rv) = 0 + self.constraints.push(CpuConstraint::new_eq( + CpuConstraintLabel::LoadValueBinding, + cpu_has_read, + cpu_rv, + bus_rv, + )); + // Selector binding: cpu_has_read == bus_has_read + self.add_equality_constraint(CpuConstraintLabel::LoadSelectorBinding, cpu_has_read, bus_has_read); + // has_read * (read_addr - pack(ra_bits)) = 0 + self.constraints.push(CpuConstraint::new_terms( + CpuConstraintLabel::LoadAddressBinding, + cpu_has_read, + false, + pack_addr_bits::(cpu_read_addr, twist.ra_bits.clone(), layout, j), + )); + } else { + // Explicitly force the read selector to 0 when the lane is unbound. + self.constraints.push(CpuConstraint::new_zero( + CpuConstraintLabel::LoadSelectorBinding, + self.const_one_col, + bus_has_read, + )); + } - // Optional: bind CPU increment semantics if provided. 
- if let Some(cpu_inc) = cpu_inc { + if let Some(cpu_has_write) = cpu_has_write { + let cpu_wv = cpu_wv.expect("Twist write binding requires wv column"); + let cpu_write_addr = cpu_write_addr.expect("Twist write binding requires write_addr column"); + // has_write * (wv_cpu - bus_wv) = 0 self.constraints.push(CpuConstraint::new_eq( - CpuConstraintLabel::IncrementBinding, + CpuConstraintLabel::StoreValueBinding, cpu_has_write, - cpu_inc, - bus_inc, + cpu_wv, + bus_wv, + )); + // Selector binding: cpu_has_write == bus_has_write + self.add_equality_constraint(CpuConstraintLabel::StoreSelectorBinding, cpu_has_write, bus_has_write); + // has_write * (write_addr - pack(wa_bits)) = 0 + self.constraints.push(CpuConstraint::new_terms( + CpuConstraintLabel::StoreAddressBinding, + cpu_has_write, + false, + pack_addr_bits::(cpu_write_addr, twist.wa_bits.clone(), layout, j), + )); + // Optional: bind CPU increment semantics if provided. + if let Some(cpu_inc) = cpu_inc { + self.constraints.push(CpuConstraint::new_eq( + CpuConstraintLabel::IncrementBinding, + cpu_has_write, + cpu_inc, + bus_inc, + )); + } + } else { + // Explicitly force the write selector to 0 when the lane is unbound. 
+ self.constraints.push(CpuConstraint::new_zero( + CpuConstraintLabel::StoreSelectorBinding, + self.const_one_col, + bus_has_write, )); } @@ -473,28 +506,16 @@ impl CpuConstraintBuilder { )); // Read address bits: - // - Padding: (1 - has_read) * bit = 0 // - Bitness: bit is 0 when inactive, boolean when active for col_id in twist.ra_bits.clone() { let bit = layout.bus_cell(col_id, j); - self.constraints.push(CpuConstraint::new_zero_negated( - CpuConstraintLabel::ReadAddressBitsZeroPadding, - bus_has_read, - bit, - )); self.add_gated_bit_constraint(CpuConstraintLabel::TwistReadAddrBitBitness, bit, bus_has_read); } // Write address bits: - // - Padding: (1 - has_write) * bit = 0 // - Bitness: bit is 0 when inactive, boolean when active for col_id in twist.wa_bits.clone() { let bit = layout.bus_cell(col_id, j); - self.constraints.push(CpuConstraint::new_zero_negated( - CpuConstraintLabel::WriteAddressBitsZeroPadding, - bus_has_write, - bit, - )); self.add_gated_bit_constraint(CpuConstraintLabel::TwistWriteAddrBitBitness, bit, bus_has_write); } } @@ -524,43 +545,67 @@ impl CpuConstraintBuilder { /// Add constraints for a Shout (lookup) instance using an explicit per-instance CPU binding. pub fn add_shout_instance_bound(&mut self, layout: &BusLayout, shout: &ShoutCols, cpu: &ShoutCpuBinding) { + self.add_shout_instance_linkage_bound(layout, shout, cpu); + self.add_shout_instance_padding(layout, shout); + } + + /// Add Shout CPU linkage constraints only (no selector bitness / inactive padding). + pub fn add_shout_instance_linkage_bound(&mut self, layout: &BusLayout, shout: &ShoutCols, cpu: &ShoutCpuBinding) { for j in 0..layout.chunk_size { // Bus column indices let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); - let bus_val = layout.bus_cell(shout.val, j); + let bus_val = layout.bus_cell(shout.primary_val(), j); // CPU columns are assumed to be chunked (contiguous, per-step): col(j) = col_base + j. 
- let cpu_has_lookup = cpu.has_lookup + j; - let cpu_val = cpu.val + j; - - // Ensure bus selector is boolean so gated-bit constraints imply true {0,1} bitness. - self.add_boolean_constraint(CpuConstraintLabel::ShoutHasLookupBoolean, bus_has_lookup); + let cpu_has_lookup = (cpu.has_lookup != CPU_BUS_COL_DISABLED).then_some(cpu.has_lookup + j); + let cpu_val = (cpu.val != CPU_BUS_COL_DISABLED).then_some(cpu.val + j); - // Value binding: is_lookup * (lookup_output - bus_val) = 0 - self.constraints.push(CpuConstraint::new_eq( - CpuConstraintLabel::LookupValueBinding, - cpu_has_lookup, - cpu_val, - bus_val, - )); + if let (Some(cpu_has_lookup), Some(cpu_val)) = (cpu_has_lookup, cpu_val) { + // Value binding: is_lookup * (lookup_output - bus_val) = 0 + self.constraints.push(CpuConstraint::new_eq( + CpuConstraintLabel::LookupValueBinding, + cpu_has_lookup, + cpu_val, + bus_val, + )); - // Selector binding: cpu_has_lookup == bus_has_lookup - self.add_equality_constraint( - CpuConstraintLabel::LookupSelectorBinding, - cpu_has_lookup, - bus_has_lookup, - ); + // Selector binding: cpu_has_lookup == bus_has_lookup + self.add_equality_constraint( + CpuConstraintLabel::LookupSelectorBinding, + cpu_has_lookup, + bus_has_lookup, + ); + } else if cpu_has_lookup.is_some() || cpu_val.is_some() { + debug_assert!( + false, + "ShoutCpuBinding must set both has_lookup and val, or disable both with CPU_BUS_COL_DISABLED" + ); + } - // Optional key binding (bit-pack): is_lookup * (lookup_key - pack(addr_bits)) = 0 + // Optional key binding (bit-pack): is_lookup * (lookup_key - pack(addr_bits)) = 0. + // + // If CPU selector linkage is disabled for this lane, use bus `has_lookup` as the gate. 
if let Some(cpu_addr_base) = cpu.addr { let cpu_addr = cpu_addr_base + j; + let gate_col = cpu_has_lookup.unwrap_or(bus_has_lookup); self.constraints.push(CpuConstraint::new_terms( CpuConstraintLabel::LookupKeyBinding, - cpu_has_lookup, + gate_col, false, pack_addr_bits::(cpu_addr, shout.addr_bits.clone(), layout, j), )); } + } + } + + /// Add Shout selector/bitness/padding constraints only (no CPU linkage). + pub fn add_shout_instance_padding(&mut self, layout: &BusLayout, shout: &ShoutCols) { + for j in 0..layout.chunk_size { + let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); + let bus_val = layout.bus_cell(shout.primary_val(), j); + + // Ensure bus selector is boolean so gated-bit constraints imply true {0,1} bitness. + self.add_boolean_constraint(CpuConstraintLabel::ShoutHasLookupBoolean, bus_has_lookup); // Padding: (1 - has_lookup) * val = 0 self.constraints.push(CpuConstraint::new_zero_negated( @@ -570,20 +615,75 @@ impl CpuConstraintBuilder { )); // Lookup key bits: - // - Padding: (1 - has_lookup) * bit = 0 // - Bitness: bit is 0 when inactive, boolean when active for col_id in shout.addr_bits.clone() { let bit = layout.bus_cell(col_id, j); - self.constraints.push(CpuConstraint::new_zero_negated( - CpuConstraintLabel::LookupAddressBitsZeroPadding, - bus_has_lookup, - bit, - )); self.add_gated_bit_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit, bus_has_lookup); } } } + /// Add Shout addr-bit/value padding constraints without selector booleanity. 
+ pub fn add_shout_instance_padding_without_selector_bitness(&mut self, layout: &BusLayout, shout: &ShoutCols) { + for j in 0..layout.chunk_size { + let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); + let bus_val = layout.bus_cell(shout.primary_val(), j); + + // Padding: (1 - has_lookup) * val = 0 + self.constraints.push(CpuConstraint::new_zero_negated( + CpuConstraintLabel::LookupValueZeroPadding, + bus_has_lookup, + bus_val, + )); + + // Lookup key bits: + // - Bitness: bit is 0 when inactive, boolean when active + for col_id in shout.addr_bits.clone() { + let bit = layout.bus_cell(col_id, j); + self.add_gated_bit_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit, bus_has_lookup); + } + } + } + + /// Add Shout selector/value padding only (no addr-bit constraints). + pub fn add_shout_instance_padding_value_only(&mut self, layout: &BusLayout, shout: &ShoutCols) { + for j in 0..layout.chunk_size { + let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); + let bus_val = layout.bus_cell(shout.primary_val(), j); + + self.add_boolean_constraint(CpuConstraintLabel::ShoutHasLookupBoolean, bus_has_lookup); + self.constraints.push(CpuConstraint::new_zero_negated( + CpuConstraintLabel::LookupValueZeroPadding, + bus_has_lookup, + bus_val, + )); + } + } + + /// Add Shout value padding only (no selector booleanity, no addr-bit constraints). + pub fn add_shout_instance_value_padding_only(&mut self, layout: &BusLayout, shout: &ShoutCols) { + for j in 0..layout.chunk_size { + let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); + let bus_val = layout.bus_cell(shout.primary_val(), j); + + self.constraints.push(CpuConstraint::new_zero_negated( + CpuConstraintLabel::LookupValueZeroPadding, + bus_has_lookup, + bus_val, + )); + } + } + + /// Add unconditional addr-bit booleanity constraints for one shared-address Shout group. 
+ pub fn add_shout_instance_addr_bit_bitness(&mut self, layout: &BusLayout, shout: &ShoutCols) { + for j in 0..layout.chunk_size { + for col_id in shout.addr_bits.clone() { + let bit = layout.bus_cell(col_id, j); + self.add_boolean_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit); + } + } + } + /// Add an unconditional equality constraint: `left == right` (always). /// /// This is used for selector binding (is_load == has_read, etc.). @@ -677,7 +777,7 @@ impl CpuConstraintBuilder { let m = self.m; let num_constraints = self.constraints.len(); - // We need n >= num_constraints for square CCS + // We need enough rows to place all generated constraints. if num_constraints > n { return Err(format!( "too many constraints ({}) for CCS with n={}", @@ -718,40 +818,21 @@ impl CpuConstraintBuilder { let b = Mat::from_row_major(n, m, b_data); let c = Mat::from_row_major(n, m, c_data); - // Convert to CCS: f(x1, x2, x3) = x1 * x2 - x3 - // For identity-first CCS (square), we add I_n as M_0 - if n == m { - let i_n = Mat::identity(n); - let f = SparsePoly::new( - 4, - vec![ - Term { - coeff: F::ONE, - exps: vec![0, 1, 1, 0], // x1 * x2 - }, - Term { - coeff: -F::ONE, - exps: vec![0, 0, 0, 1], // -x3 - }, - ], - ); - CcsStructure::new(vec![i_n, a, b, c], f).map_err(|e| format!("failed to create CCS: {:?}", e)) - } else { - let f = SparsePoly::new( - 3, - vec![ - Term { - coeff: F::ONE, - exps: vec![1, 1, 0], // x1 * x2 - }, - Term { - coeff: -F::ONE, - exps: vec![0, 0, 1], // -x3 - }, - ], - ); - CcsStructure::new(vec![a, b, c], f).map_err(|e| format!("failed to create CCS: {:?}", e)) - } + // Convert to CCS: f(x0, x1, x2) = x0 * x1 - x2 + let f = SparsePoly::new( + 3, + vec![ + Term { + coeff: F::ONE, + exps: vec![1, 1, 0], // x0 * x1 + }, + Term { + coeff: -F::ONE, + exps: vec![0, 0, 1], // -x2 + }, + ], + ); + CcsStructure::new(vec![a, b, c], f).map_err(|e| format!("failed to create CCS: {:?}", e)) } /// Extend an existing CCS with bus binding constraints. 
@@ -886,6 +967,39 @@ pub fn extend_ccs_with_shared_cpu_bus_constraints], mem_insts: &[MemInstance], +) -> Result, String> { + let shout_cpu: Vec> = shout_cpu.iter().cloned().map(Some).collect(); + let empty_groups = std::collections::HashMap::new(); + extend_ccs_with_shared_cpu_bus_constraints_optional_shout( + base_ccs, + m_in, + const_one_col, + &shout_cpu, + twist_cpu, + lut_insts, + mem_insts, + &empty_groups, + &empty_groups, + ) +} + +/// Extend a CPU CCS with shared-bus constraints, allowing per-lane Shout linkage opt-out. +/// +/// When a Shout lane binding is `None`, only canonical Shout padding/bitness constraints are +/// injected for that lane (no CPU selector/value/key linkage equalities). +pub fn extend_ccs_with_shared_cpu_bus_constraints_optional_shout< + F: Field + PrimeCharacteristicRing + Copy + Eq + Send + Sync, + Cmt, +>( + base_ccs: &CcsStructure, + m_in: usize, + const_one_col: usize, + shout_cpu: &[Option], + twist_cpu: &[TwistCpuBinding], + lut_insts: &[LutInstance], + mem_insts: &[MemInstance], + shout_addr_groups: &std::collections::HashMap, + shout_selector_groups: &std::collections::HashMap, ) -> Result, String> { let total_shout_lanes: usize = lut_insts.iter().map(|l| l.lanes.max(1)).sum(); if shout_cpu.len() != total_shout_lanes { @@ -936,13 +1050,17 @@ pub fn extend_ccs_with_shared_cpu_bus_constraints::new(base_ccs.n, base_ccs.m, const_one_col); + let mut addr_range_counts = std::collections::HashMap::<(usize, usize), usize>::new(); + for inst_cols in layout.shout_cols.iter() { + for lane_cols in inst_cols.lanes.iter() { + let key = (lane_cols.addr_bits.start, lane_cols.addr_bits.end); + *addr_range_counts.entry(key).or_insert(0) += 1; + } + } + let mut addr_range_bitness_added = std::collections::HashSet::<(usize, usize)>::new(); + let mut selector_bitness_added = std::collections::HashSet::::new(); + let mut shout_key_binding_added = std::collections::HashSet::<(bool, usize, usize, usize, usize)>::new(); let mut shout_lane_idx = 
0usize; for inst_cols in layout.shout_cols.iter() { for lane_cols in inst_cols.lanes.iter() { + let key = (lane_cols.addr_bits.start, lane_cols.addr_bits.end); + let shared_addr_group = addr_range_counts.get(&key).copied().unwrap_or(0) > 1; let cpu = shout_cpu .get(shout_lane_idx) .ok_or_else(|| format!("missing shout_cpu binding at lane_idx={shout_lane_idx}"))?; - builder.add_shout_instance_bound(&layout, lane_cols, cpu); + if let Some(cpu) = cpu { + let mut dedup_cpu = cpu.clone(); + if let Some(addr_base) = dedup_cpu.addr { + let (is_bus_gate, gate_base) = if dedup_cpu.has_lookup == CPU_BUS_COL_DISABLED { + (true, lane_cols.has_lookup) + } else { + (false, dedup_cpu.has_lookup) + }; + let key_sig = ( + is_bus_gate, + gate_base, + addr_base, + lane_cols.addr_bits.start, + lane_cols.addr_bits.end, + ); + if !shout_key_binding_added.insert(key_sig) { + dedup_cpu.addr = None; + } + } + builder.add_shout_instance_linkage_bound(&layout, lane_cols, &dedup_cpu); + } else { + // no linkage + } + let selector_first = selector_bitness_added.insert(lane_cols.has_lookup); + if shared_addr_group { + // Shared address bits across multiple table instances cannot use per-instance + // `(1-has_lookup)*addr_bit=0` gating: inactive instances would overconstrain + // active ones. Enforce has/value padding per-instance and addr-bit booleanity + // once per shared range. 
+ if selector_first { + builder.add_shout_instance_padding_value_only(&layout, lane_cols); + } else { + builder.add_shout_instance_value_padding_only(&layout, lane_cols); + } + if addr_range_bitness_added.insert(key) { + builder.add_shout_instance_addr_bit_bitness(&layout, lane_cols); + } + } else { + if selector_first { + builder.add_shout_instance_padding(&layout, lane_cols); + } else { + builder.add_shout_instance_padding_without_selector_bitness(&layout, lane_cols); + } + } shout_lane_idx += 1; } } @@ -1095,7 +1268,7 @@ pub fn create_shout_padding_constraints(layout: &BusLayout, shout: &Sh let mut constraints = Vec::new(); for j in 0..layout.chunk_size { let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); - let bus_val = layout.bus_cell(shout.val, j); + let bus_val = layout.bus_cell(shout.primary_val(), j); // (1 - has_lookup) * val = 0 constraints.push(CpuConstraint::new_zero_negated( diff --git a/crates/neo-memory/src/cpu/r1cs_adapter.rs b/crates/neo-memory/src/cpu/r1cs_adapter.rs index bf321513..144483af 100644 --- a/crates/neo-memory/src/cpu/r1cs_adapter.rs +++ b/crates/neo-memory/src/cpu/r1cs_adapter.rs @@ -5,8 +5,12 @@ use crate::addr::write_addr_bits_dim_major_le_into_bus; use crate::builder::CpuArithmetization; -use crate::cpu::bus_layout::{build_bus_layout_for_instances_with_shout_and_twist_lanes, BusLayout}; -use crate::cpu::constraints::{extend_ccs_with_shared_cpu_bus_constraints, ShoutCpuBinding, TwistCpuBinding}; +use crate::cpu::bus_layout::{ + build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes, BusLayout, ShoutInstanceShape, +}; +use crate::cpu::constraints::{ + extend_ccs_with_shared_cpu_bus_constraints_optional_shout, ShoutCpuBinding, TwistCpuBinding, CPU_BUS_COL_DISABLED, +}; use crate::mem_init::MemInit; use crate::plain::LutTable; use crate::plain::PlainMemLayout; @@ -41,8 +45,9 @@ pub struct SharedCpuBusConfig { /// /// The bus tail contains one Shout instance per `table_id` known to this CPU (from `tables` in 
`R1csCpu::new`). /// - /// Each Shout instance may have multiple lookup lanes; this map must provide one `ShoutCpuBinding` - /// per lane in lane-index order. + /// Each Shout instance may have multiple lookup lanes. + /// - Non-empty vector: CPU linkage bindings in lane-index order. + /// - Empty vector: no CPU linkage for that table's bus lanes (padding/bitness only). pub shout_cpu: HashMap>, /// Per-memory CPU→bus bindings (twist_id -> binding). /// @@ -51,6 +56,16 @@ pub struct SharedCpuBusConfig { /// Each Twist instance may have multiple access lanes (`PlainMemLayout.lanes`); this map must /// provide one `TwistCpuBinding` per lane in lane-index order. pub twist_cpu: HashMap>, + /// Optional per-table address-sharing group ids (table_id -> group_id). + /// + /// Tables with the same group_id share `addr_bits` columns in the bus layout. + /// Leave empty when no families share address bits. Populated by trace mode for column efficiency. + pub shout_addr_groups: HashMap, + /// Optional per-table selector-sharing group ids (table_id -> group_id). + /// + /// Tables with the same group_id share `has_lookup` columns in the bus layout. + /// Leave empty when no families share selectors. Populated by trace mode for column efficiency. + pub shout_selector_groups: HashMap, } #[derive(Clone, Debug)] @@ -108,7 +123,7 @@ where tables: &HashMap>, table_specs: &HashMap, chunk_to_witness: Box]) -> Vec + Send + Sync>, - ) -> Self { + ) -> Result { let mut shout_meta = HashMap::new(); for (id, table) in tables { shout_meta.insert(*id, (table.d, table.n_side)); @@ -116,6 +131,12 @@ where for (id, spec) in table_specs { let (d, n_side) = match spec { LutTableSpec::RiscvOpcode { xlen, .. } => (xlen.saturating_mul(2), 2usize), + LutTableSpec::RiscvOpcodePacked { .. } => { + return Err("RiscvOpcodePacked is not supported in the shared-bus R1csCpu path".into()); + } + LutTableSpec::RiscvOpcodeEventTablePacked { .. 
} => { + return Err("RiscvOpcodeEventTablePacked is not supported in the shared-bus R1csCpu path".into()); + } LutTableSpec::IdentityU32 => (32usize, 2usize), }; match shout_meta.entry(*id) { @@ -135,7 +156,7 @@ where } } - Self { + Ok(Self { ccs, params, committer, @@ -144,7 +165,7 @@ where shared_cpu_bus: None, chunk_to_witness, _phantom: PhantomData, - } + }) } fn shared_bus_schema( @@ -157,7 +178,7 @@ where let mut mem_ids: Vec = bus.mem_layouts.keys().copied().collect(); mem_ids.sort_unstable(); - let mut shout_ell_addrs_and_lanes = Vec::with_capacity(table_ids.len()); + let mut shout_shapes = Vec::with_capacity(table_ids.len()); for table_id in &table_ids { let (d, n_side) = self .shout_meta @@ -174,14 +195,16 @@ where let lanes = bus .shout_cpu .get(table_id) - .ok_or_else(|| format!("shared_cpu_bus: missing shout_cpu binding for table_id={table_id}"))? - .len(); - if lanes == 0 { - return Err(format!( - "shared_cpu_bus: shout_cpu bindings for table_id={table_id} must be non-empty" - )); - } - shout_ell_addrs_and_lanes.push((ell_addr, lanes)); + .map(|v| v.len()) + .unwrap_or(0) + .max(1); + shout_shapes.push(ShoutInstanceShape { + ell_addr, + lanes, + n_vals: 1usize, + addr_group: bus.shout_addr_groups.get(table_id).copied(), + selector_group: bus.shout_selector_groups.get(table_id).copied(), + }); } let mut twist_ell_addrs_and_lanes = Vec::with_capacity(mem_ids.len()); @@ -201,11 +224,11 @@ where twist_ell_addrs_and_lanes.push((ell_addr, layout.lanes.max(1))); } - let layout = build_bus_layout_for_instances_with_shout_and_twist_lanes( + let layout = build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes( self.ccs.m, self.m_in, chunk_size, - shout_ell_addrs_and_lanes, + shout_shapes, twist_ell_addrs_and_lanes, )?; Ok((table_ids, mem_ids, layout)) @@ -244,16 +267,19 @@ where id: u32, bus_base: usize, chunk_size: usize, - cols: &[(&str, usize)], + cols: &[(&str, usize, bool)], ) -> Result<(), String> { let max_step_offset = chunk_size 
.checked_sub(1) .ok_or_else(|| "shared_cpu_bus: chunk_size must be >= 1".to_string())?; - for (label, col) in cols { + for (label, col, allow_bus_overlap) in cols { + if *col == CPU_BUS_COL_DISABLED { + continue; + } let max_col = col .checked_add(max_step_offset) .ok_or_else(|| format!("shared_cpu_bus: {kind} binding for id={id} overflows usize"))?; - if max_col >= bus_base { + if !allow_bus_overlap && max_col >= bus_base { return Err(format!( "shared_cpu_bus: {kind} binding for id={id} uses {label}={col} (max={max_col}), but bus_base={bus_base} (CPU bindings must be < bus_base to avoid overlapping the bus tail)" )); @@ -265,23 +291,23 @@ where // Build per-lane binding vectors in canonical order (id-sorted, then lane index). let total_shout_lanes: usize = table_ids .iter() - .map(|id| cfg.shout_cpu.get(id).map(|v| v.len()).unwrap_or(0)) + .map(|id| cfg.shout_cpu.get(id).map(|v| v.len().max(1)).unwrap_or(1)) .sum(); - let mut shout_cpu: Vec = Vec::with_capacity(total_shout_lanes); + let mut shout_cpu: Vec> = Vec::with_capacity(total_shout_lanes); for table_id in &table_ids { let bindings = cfg .shout_cpu .get(table_id) - .ok_or_else(|| format!("shared_cpu_bus: missing shout_cpu binding for table_id={table_id}"))?; + .map(Vec::as_slice) + .unwrap_or(&[]); if bindings.is_empty() { - return Err(format!( - "shared_cpu_bus: shout_cpu bindings for table_id={table_id} must be non-empty" - )); + shout_cpu.push(None); + continue; } for (lane_idx, b) in bindings.iter().enumerate() { - let mut cols = vec![("has_lookup", b.has_lookup), ("val", b.val)]; + let mut cols = vec![("has_lookup", b.has_lookup, false), ("val", b.val, false)]; if let Some(addr) = b.addr { - cols.push(("addr", addr)); + cols.push(("addr", addr, false)); } validate_cpu_binding_cols( &format!("shout_cpu[lane={lane_idx}]"), @@ -290,7 +316,7 @@ where chunk_size, &cols, )?; - shout_cpu.push(b.clone()); + shout_cpu.push(Some(b.clone())); } } let total_twist_lanes: usize = mem_ids @@ -321,15 +347,16 @@ where } 
for (lane_idx, b) in bindings.iter().enumerate() { let mut cols = vec![ - ("has_read", b.has_read), - ("has_write", b.has_write), - ("read_addr", b.read_addr), - ("write_addr", b.write_addr), - ("rv", b.rv), - ("wv", b.wv), + // Selector columns may intentionally source from bus-tail decode lookups. + ("has_read", b.has_read, true), + ("has_write", b.has_write, true), + ("read_addr", b.read_addr, false), + ("write_addr", b.write_addr, false), + ("rv", b.rv, false), + ("wv", b.wv, false), ]; if let Some(inc) = b.inc { - cols.push(("inc", inc)); + cols.push(("inc", inc, false)); } let kind = format!("twist_cpu[lane={lane_idx}]"); validate_cpu_binding_cols(&kind, *mem_id, bus_base, chunk_size, &cols)?; @@ -365,10 +392,11 @@ where let lanes = cfg .shout_cpu .get(table_id) - .ok_or_else(|| format!("shared_cpu_bus: missing shout_cpu binding for table_id={table_id}"))? - .len() + .map(|v| v.len()) + .unwrap_or(0) .max(1); lut_insts.push(LutInstance { + table_id: *table_id, comms: Vec::new(), k: 0, d, @@ -378,6 +406,8 @@ where ell, table_spec: None, table: Vec::new(), + addr_group: cfg.shout_addr_groups.get(table_id).copied(), + selector_group: cfg.shout_selector_groups.get(table_id).copied(), }); } @@ -389,6 +419,7 @@ where .ok_or_else(|| format!("shared_cpu_bus: missing mem_layout for mem_id={mem_id}"))?; let ell = layout.n_side.trailing_zeros() as usize; mem_insts.push(MemInstance { + mem_id: *mem_id, comms: Vec::new(), k: layout.k, d: layout.d, @@ -400,7 +431,7 @@ where }); } - self.ccs = extend_ccs_with_shared_cpu_bus_constraints( + self.ccs = extend_ccs_with_shared_cpu_bus_constraints_optional_shout( &self.ccs, self.m_in, cfg.const_one_col, @@ -408,6 +439,8 @@ where &twist_cpu, &lut_insts, &mem_insts, + &cfg.shout_addr_groups, + &cfg.shout_selector_groups, ) .map_err(|e| format!("shared_cpu_bus: failed to inject constraints: {e}"))?; @@ -662,7 +695,7 @@ where } } - // Shout lanes: addr_bits, has_lookup, val. + // Shout lanes: addr_bits, has_lookup, vals[0]. 
for (i, table_id) in shared.table_ids.iter().enumerate() { let inst_cols = &shared.layout.shout_cols[i]; let (d, n_side) = self @@ -685,7 +718,7 @@ where ell, ); z_vec[shared.layout.bus_cell(shout_cols.has_lookup, j)] = Goldilocks::ONE; - z_vec[shared.layout.bus_cell(shout_cols.val, j)] = val; + z_vec[shared.layout.bus_cell(shout_cols.primary_val(), j)] = val; } } } @@ -821,5 +854,26 @@ where ) -> Result, McsWitness)>, Self::Error> { self.build_ccs_chunks(trace, 1) } + + fn shout_addr_groups(&self) -> &HashMap { + self.shared_cpu_bus + .as_ref() + .map(|s| &s.cfg.shout_addr_groups) + .unwrap_or_else(|| { + static EMPTY: std::sync::LazyLock> = std::sync::LazyLock::new(HashMap::new); + &EMPTY + }) + } + + fn shout_selector_groups(&self) -> &HashMap { + self.shared_cpu_bus + .as_ref() + .map(|s| &s.cfg.shout_selector_groups) + .unwrap_or_else(|| { + static EMPTY: std::sync::LazyLock> = std::sync::LazyLock::new(HashMap::new); + &EMPTY + }) + } + type Error = String; } diff --git a/crates/neo-memory/src/lib.rs b/crates/neo-memory/src/lib.rs index ed30582c..87b3b540 100644 --- a/crates/neo-memory/src/lib.rs +++ b/crates/neo-memory/src/lib.rs @@ -8,7 +8,7 @@ //! //! # RISC-V Support //! -//! The current proving integration is RV32-focused (e.g. the shared-bus RV32 B1 path assumes +//! The current proving integration is RV32-focused (e.g. the shared-bus RV32 trace path assumes //! `xlen == 32`, no compressed instructions, and 4-byte aligned control flow). //! RV64 proving is not yet supported by the Shout key encoding used in this path. //! @@ -33,6 +33,7 @@ pub mod output_check; pub mod plain; pub mod riscv; pub mod shout; +pub mod sparse_matrix; pub mod sparse_time; pub mod sumcheck_proof; pub mod ts_common; diff --git a/crates/neo-memory/src/riscv/ccs.rs b/crates/neo-memory/src/riscv/ccs.rs index 81e8a52e..a04bd47f 100644 --- a/crates/neo-memory/src/riscv/ccs.rs +++ b/crates/neo-memory/src/riscv/ccs.rs @@ -1,89 +1,34 @@ -//! 
RV32 "B1" RISC-V step CCS (shared-bus compatible). +//! RV32 trace-wiring CCS (shared-bus compatible). //! -//! This module provides a **sound, shared-bus-compatible** step circuit for a small, -//! MVP RV32 subset. The circuit is expressed as an R1CS→CCS: -//! - `A(z) * B(z) = C(z)` with `C = 0` for almost all rows -//! - when `n == m`, we include an identity-first `M0 = I_n` to match `neo_ccs::r1cs_to_ccs` -//! -//! The witness `z` includes a **reserved bus tail** whose column schema matches -//! `cpu::bus_layout::BusLayout`. The bus tail itself is written from `StepTrace` -//! events by `R1csCpu` (shared-bus mode), and is verified by the Twist/Shout sidecars. -//! -//! ## Execution model (Phase 1) -//! - Each lane `j` in a chunk is one architectural step. -//! - `is_active[j] ∈ {0,1}` gates padding; inactive lanes keep `(pc, regs)` constant and perform -//! no bus activity (enforced via shared-bus padding constraints). -//! - Intra-chunk continuity is enforced: `pc_in[j+1] = pc_out[j]` and `regs_in[j+1] = regs_out[j]`. -//! - The chunk exposes public boundary state (`pc0/regs0` at lane 0 and `pc_final/regs_final` at the -//! last lane). Multi-chunk executions must chain these boundary values across chunks at a higher layer. -//! -//! The CCS here constrains the **CPU glue**: -//! - ROM fetch binding (`PROG_ID`) via shared-bus bindings -//! - instruction decode from a committed 32-bit instruction word -//! - register-file update pattern -//! - RAM load/store binding (`RAM_ID`) via shared-bus bindings -//! - Shout key wiring for `ADD` lookups (table id 3) -//! -//! Supported RV32IMA subset (RV32, word-only memory, no compressed): -//! - ALU (R-type): `ADD`, `SUB`, `SLL`, `SLT`, `SLTU`, `XOR`, `SRL`, `SRA`, `OR`, `AND` -//! - M (R-type, in-circuit): `MUL`, `MULH`, `MULHU`, `MULHSU`, `DIV`, `DIVU`, `REM`, `REMU` -//! - ALU (I-type): `ADDI`, `SLTI`, `SLTIU`, `XORI`, `ORI`, `ANDI`, `SLLI`, `SRLI`, `SRAI` -//! 
- Memory (byte/half/word): `LB`, `LBU`, `LH`, `LHU`, `LW`, `SB`, `SH`, `SW` -//! - Atomics (word): `AMOADD.W`, `AMOAND.W`, `AMOOR.W`, `AMOXOR.W`, `AMOSWAP.W` -//! - Branch: `BEQ`, `BNE`, `BLT`, `BGE`, `BLTU`, `BGEU` -//! - Jump: `JAL`, `JALR` -//! - U-type: `LUI`, `AUIPC` -//! - System: `FENCE`, `ECALL(imm=0)` (halts unless a0 is a Jolt marker/print id) - -use std::collections::HashMap; - -use neo_ccs::relations::CcsStructure; -use p3_field::PrimeCharacteristicRing; -use p3_goldilocks::Goldilocks as F; - -use crate::plain::PlainMemLayout; -use crate::riscv::lookups::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID}; +//! This module exposes the trace-mode CCS layout/witness/builders and shared CPU-bus +//! requirements/configuration used by the `Rv32TraceWiring` proving flow. mod bus_bindings; -mod config; mod constants; mod constraint_builder; -mod layout; -mod witness; +mod trace; -pub use bus_bindings::rv32_b1_shared_cpu_bus_config; -pub use layout::Rv32B1Layout; -pub use witness::{ - rv32_b1_chunk_to_full_witness, rv32_b1_chunk_to_full_witness_checked, rv32_b1_chunk_to_witness, - rv32_b1_chunk_to_witness_checked, +pub use bus_bindings::{ + rv32_trace_shared_bus_extraction, rv32_trace_shared_bus_extraction_with_specs, rv32_trace_shared_bus_requirements, + rv32_trace_shared_bus_requirements_with_specs, rv32_trace_shared_cpu_bus_config, + rv32_trace_shared_cpu_bus_config_with_specs, TraceSharedBusExtraction, TraceShoutBusSpec, +}; +pub use trace::{ + build_rv32_trace_wiring_ccs, build_rv32_trace_wiring_ccs_with_reserved_rows, + rv32_trace_ccs_witness_from_exec_table, rv32_trace_ccs_witness_from_trace_witness, Rv32TraceCcsLayout, }; -/// Verifier-side step-linking pairs for chaining multi-chunk executions. 
-/// -/// For each adjacent pair of shard steps (chunks) `i` and `i+1`, require: -/// - `pc_final[i] == pc0[i+1]` -/// - `regs_final[i][r] == regs0[i+1][r]` for `r ∈ [0..32)` -/// - `halted_out[i] == halted_in[i+1]` -/// -/// This is the minimal glue needed to make `chunk_size` a semantic no-op: the CPU state must form -/// one contiguous execution across chunks. -pub fn rv32_b1_step_linking_pairs(layout: &Rv32B1Layout) -> Vec<(usize, usize)> { - let mut pairs = Vec::with_capacity(34); - pairs.push((layout.pc_final, layout.pc0)); - for r in 0..32 { - pairs.push((layout.regs_final_start + r, layout.regs0_start + r)); - } - pairs.push((layout.halted_out, layout.halted_in)); - pairs -} +use constants::{ + ADD_TABLE_ID, AND_TABLE_ID, DIVU_TABLE_ID, DIV_TABLE_ID, EQ_TABLE_ID, MULHSU_TABLE_ID, MULHU_TABLE_ID, + MULH_TABLE_ID, MUL_TABLE_ID, NEQ_TABLE_ID, OR_TABLE_ID, REMU_TABLE_ID, REM_TABLE_ID, SLL_TABLE_ID, SLTU_TABLE_ID, + SLT_TABLE_ID, SRA_TABLE_ID, SRL_TABLE_ID, SUB_TABLE_ID, XOR_TABLE_ID, +}; -/// Minimal Shout table set intended for small RV32 programs that only need: -/// - `ADD` (address/ALU wiring), and -/// - `EQ`/`NEQ` (BEQ/BNE). -pub const RV32_B1_SHOUT_PROFILE_MIN3: &[u32] = &[ADD_TABLE_ID, EQ_TABLE_ID, NEQ_TABLE_ID]; +/// Minimal trace-mode Shout profile for tiny RV32 programs. +pub const RV32_TRACE_SHOUT_PROFILE_MIN3: &[u32] = &[ADD_TABLE_ID, EQ_TABLE_ID, NEQ_TABLE_ID]; -/// Full RV32I Shout table set (ids 0..=11). -pub const RV32_B1_SHOUT_PROFILE_FULL12: &[u32] = &[ +/// Full RV32I trace-mode Shout profile. +pub const RV32_TRACE_SHOUT_PROFILE_FULL12: &[u32] = &[ AND_TABLE_ID, XOR_TABLE_ID, OR_TABLE_ID, @@ -98,9 +43,8 @@ pub const RV32_B1_SHOUT_PROFILE_FULL12: &[u32] = &[ NEQ_TABLE_ID, ]; -/// Full RV32IM Shout table set (ids 0..=19). -/// M tables are optional; RV32 B1 proves M ops in-circuit and ignores their Shout lanes. -pub const RV32_B1_SHOUT_PROFILE_FULL20: &[u32] = &[ +/// Full RV32IM trace-mode Shout profile. 
+pub const RV32_TRACE_SHOUT_PROFILE_FULL20: &[u32] = &[ AND_TABLE_ID, XOR_TABLE_ID, OR_TABLE_ID, @@ -122,2619 +66,3 @@ pub const RV32_B1_SHOUT_PROFILE_FULL20: &[u32] = &[ REM_TABLE_ID, REMU_TABLE_ID, ]; - -use bus_bindings::injected_bus_constraints_len; -use config::{derive_mem_ids_and_ell_addrs, derive_shout_ids_and_ell_addrs}; -use constraint_builder::{build_r1cs_ccs, Constraint}; -use layout::build_layout_with_m; - -use constants::{ - ADD_TABLE_ID, AND_TABLE_ID, DIVU_TABLE_ID, DIV_TABLE_ID, EQ_TABLE_ID, MULHSU_TABLE_ID, MULHU_TABLE_ID, - MULH_TABLE_ID, MUL_TABLE_ID, NEQ_TABLE_ID, OR_TABLE_ID, REMU_TABLE_ID, REM_TABLE_ID, RV32_XLEN, SLL_TABLE_ID, - SLTU_TABLE_ID, SLT_TABLE_ID, SRA_TABLE_ID, SRL_TABLE_ID, SUB_TABLE_ID, XOR_TABLE_ID, -}; - -fn pow2_u64(i: usize) -> u64 { - 1u64 << i -} - -fn enforce_u32_bits( - constraints: &mut Vec>, - one: usize, - value_col: usize, - bits_start: usize, - chunk_size: usize, - j: usize, -) { - // bit_i * (1 - bit_i) = 0 for each bit. - for bit in 0..32 { - let b = bits_start + bit * chunk_size + j; - constraints.push(Constraint { - condition_col: b, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (b, -F::ONE)], - c_terms: Vec::new(), - }); - } - - // value = sum_i 2^i * bit_i - let mut terms = vec![(value_col, F::ONE)]; - for bit in 0..32 { - let b = bits_start + bit * chunk_size + j; - terms.push((b, -F::from_u64(pow2_u64(bit)))); - } - constraints.push(Constraint::terms(one, false, terms)); -} - -fn semantic_constraints( - layout: &Rv32B1Layout, - mem_layouts: &HashMap, -) -> Result>, String> { - let one = layout.const_one; - - let mut constraints = Vec::>::new(); - - let shout_cols = |table_id: u32| { - layout - .table_ids - .binary_search(&table_id) - .ok() - .map(|idx| &layout.bus.shout_cols[idx].lanes[0]) - }; - - // The ADD table is required because this circuit uses it for address/ALU wiring (LW/SW/AUIPC/JALR). 
- let add_shout_idx = layout.shout_idx(ADD_TABLE_ID)?; - let add_cols = &layout.bus.shout_cols[add_shout_idx].lanes[0]; - - let and_cols = shout_cols(AND_TABLE_ID); - let xor_cols = shout_cols(XOR_TABLE_ID); - let or_cols = shout_cols(OR_TABLE_ID); - let sub_cols = shout_cols(SUB_TABLE_ID); - let slt_cols = shout_cols(SLT_TABLE_ID); - let sltu_cols = shout_cols(SLTU_TABLE_ID); - let sll_cols = shout_cols(SLL_TABLE_ID); - let srl_cols = shout_cols(SRL_TABLE_ID); - let sra_cols = shout_cols(SRA_TABLE_ID); - let eq_cols = shout_cols(EQ_TABLE_ID); - let neq_cols = shout_cols(NEQ_TABLE_ID); - let mul_cols = shout_cols(MUL_TABLE_ID); - let mulh_cols = shout_cols(MULH_TABLE_ID); - let mulhu_cols = shout_cols(MULHU_TABLE_ID); - let mulhsu_cols = shout_cols(MULHSU_TABLE_ID); - let div_cols = shout_cols(DIV_TABLE_ID); - let divu_cols = shout_cols(DIVU_TABLE_ID); - let rem_cols = shout_cols(REM_TABLE_ID); - let remu_cols = shout_cols(REMU_TABLE_ID); - - let ell_addr = 2 * RV32_XLEN; - for (name, cols_opt) in [ - ("ADD", Some(add_cols)), - ("AND", and_cols), - ("XOR", xor_cols), - ("OR", or_cols), - ("SUB", sub_cols), - ("SLT", slt_cols), - ("SLTU", sltu_cols), - ("SLL", sll_cols), - ("SRL", srl_cols), - ("SRA", sra_cols), - ("EQ", eq_cols), - ("NEQ", neq_cols), - ("MUL", mul_cols), - ("MULH", mulh_cols), - ("MULHU", mulhu_cols), - ("MULHSU", mulhsu_cols), - ("DIV", div_cols), - ("DIVU", divu_cols), - ("REM", rem_cols), - ("REMU", remu_cols), - ] { - if let Some(cols) = cols_opt { - if cols.addr_bits.end - cols.addr_bits.start != ell_addr { - return Err(format!( - "{name} shout bus layout mismatch: expected ell_addr={ell_addr}, got {}", - cols.addr_bits.end - cols.addr_bits.start - )); - } - } - } - - // If a Shout table isn't included, forbid the corresponding instruction variants. 
- if and_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_and(j))); - constraints.push(Constraint::zero(one, layout.is_andi(j))); - } - } - if or_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_or(j))); - constraints.push(Constraint::zero(one, layout.is_ori(j))); - } - } - if xor_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_xor(j))); - constraints.push(Constraint::zero(one, layout.is_xori(j))); - } - } - if sub_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_sub(j))); - } - } - if sll_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_sll(j))); - constraints.push(Constraint::zero(one, layout.is_slli(j))); - } - } - if srl_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_srl(j))); - constraints.push(Constraint::zero(one, layout.is_srli(j))); - } - } - if sra_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_sra(j))); - constraints.push(Constraint::zero(one, layout.is_srai(j))); - } - } - if slt_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_slt(j))); - constraints.push(Constraint::zero(one, layout.is_slti(j))); - constraints.push(Constraint::zero(one, layout.is_blt(j))); - constraints.push(Constraint::zero(one, layout.is_bge(j))); - } - } - if sltu_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_sltu(j))); - constraints.push(Constraint::zero(one, layout.is_sltiu(j))); - constraints.push(Constraint::zero(one, layout.is_bltu(j))); - constraints.push(Constraint::zero(one, layout.is_bgeu(j))); - // DIVU/REMU need SLTU to prove `rem < divisor` when divisor != 0. 
- constraints.push(Constraint::zero(one, layout.is_divu(j))); - constraints.push(Constraint::zero(one, layout.is_remu(j))); - // DIV/REM need SLTU for the signed remainder bound check. - constraints.push(Constraint::zero(one, layout.is_div(j))); - constraints.push(Constraint::zero(one, layout.is_rem(j))); - } - } - if eq_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_beq(j))); - } - } - if neq_cols.is_none() { - for j in 0..layout.chunk_size { - constraints.push(Constraint::zero(one, layout.is_bne(j))); - } - } - let _ = (mulh_cols, mulhu_cols, mulhsu_cols, div_cols, rem_cols); - - // Alignment constraints require bit-addressed memories (n_side=2). - let prog_id = PROG_ID.0; - let prog_layout = mem_layouts - .get(&prog_id) - .ok_or_else(|| format!("mem_layouts missing PROG_ID={prog_id}"))?; - if prog_layout.n_side != 2 || prog_layout.d < 2 { - return Err("RV32 B1: PROG_ID must use n_side=2 and d>=2 (bit addressing)".into()); - } - let ram_id = RAM_ID.0; - let ram_layout = mem_layouts - .get(&ram_id) - .ok_or_else(|| format!("mem_layouts missing RAM_ID={ram_id}"))?; - if ram_layout.n_side != 2 || ram_layout.d < 2 { - return Err("RV32 B1: RAM_ID must use n_side=2 and d>=2 (bit addressing)".into()); - } - - let prog = &layout.bus.twist_cols[layout.prog_twist_idx].lanes[0]; - let ram = &layout.bus.twist_cols[layout.ram_twist_idx].lanes[0]; - - let pack_interleaved_operand = - |addr_bits_start: usize, j: usize, parity: usize, value_col: usize| -> Vec<(usize, F)> { - debug_assert!(parity == 0 || parity == 1, "parity must be 0 (even) or 1 (odd)"); - let mut terms = vec![(value_col, F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = addr_bits_start + 2 * i + parity; - let bit = layout.bus.bus_cell(bit_col_id, j); - terms.push((bit, -F::from_u64(pow2_u64(i)))); - } - terms - }; - - // --- Public I/O binding (initial + final architectural state) --- - // Initial state binds to lane 0. 
- let j0 = 0usize; - constraints.push(Constraint::terms( - one, - false, - vec![(layout.pc_in(j0), F::ONE), (layout.pc0, -F::ONE)], - )); - for r in 0..32 { - constraints.push(Constraint::terms( - one, - false, - vec![(layout.reg_in(r, j0), F::ONE), (layout.regs0_start + r, -F::ONE)], - )); - } - - // Final state binds to the last lane. - let j_last = layout.chunk_size - 1; - constraints.push(Constraint::terms( - one, - false, - vec![(layout.pc_out(j_last), F::ONE), (layout.pc_final, -F::ONE)], - )); - for r in 0..32 { - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.reg_out(r, j_last), F::ONE), - (layout.regs_final_start + r, -F::ONE), - ], - )); - } - - // --- Cross-chunk halting / padding semantics (L1-style) --- - // halted_in/out are booleans. - constraints.push(Constraint::terms( - layout.halted_in, - false, - vec![(layout.halted_in, F::ONE), (one, -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.halted_out, - false, - vec![(layout.halted_out, F::ONE), (one, -F::ONE)], - )); - - // halted_in + is_active[0] = 1 (chunk starts active iff not halted). - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.halted_in, F::ONE), - (layout.is_active(j0), F::ONE), - (one, -F::ONE), - ], - )); - - // halted_out = 1 - is_active[last] + halt_effective[last]. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.halted_out, F::ONE), - (layout.is_active(j_last), F::ONE), - (layout.halt_effective(j_last), -F::ONE), - (one, -F::ONE), - ], - )); - - for j in 0..layout.chunk_size { - let is_active = layout.is_active(j); - let pc_in = layout.pc_in(j); - let pc_out = layout.pc_out(j); - let instr_word = layout.instr_word(j); - let add_a0 = layout.bus.bus_cell(add_cols.addr_bits.start + 0, j); - let add_b0 = layout.bus.bus_cell(add_cols.addr_bits.start + 1, j); - - // x0 hardwired. 
- constraints.push(Constraint::zero(one, layout.reg_in(0, j))); - constraints.push(Constraint::zero(one, layout.reg_out(0, j))); - - // is_active is boolean. - constraints.push(Constraint::terms( - is_active, - false, - vec![(is_active, F::ONE), (one, -F::ONE)], - )); - - // Inactive rows keep PC constant: (1 - is_active) * (pc_out - pc_in) = 0. - constraints.push(Constraint::terms( - is_active, - true, - vec![(pc_out, F::ONE), (pc_in, -F::ONE)], - )); - - // Instruction bits: - // - If is_active=0, force all bits to 0. - // - If is_active=1, force bits to be boolean. - for i in 0..32 { - let b = layout.instr_bit(i, j); - constraints.push(Constraint::terms(b, false, vec![(b, F::ONE), (is_active, -F::ONE)])); - } - - // Pack instr_word = Σ 2^i bit[i] - { - let mut terms = vec![(instr_word, F::ONE)]; - for i in 0..32 { - terms.push((layout.instr_bit(i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // Pack opcode/funct/fields from bits. 
- { - // opcode = bits[0..6] - let mut terms = vec![(layout.opcode(j), F::ONE)]; - for i in 0..7 { - terms.push((layout.instr_bit(i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - { - // rd_field = bits[7..11] - let mut terms = vec![(layout.rd_field(j), F::ONE)]; - for i in 0..5 { - terms.push((layout.instr_bit(7 + i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - { - // funct3 = bits[12..14] - let mut terms = vec![(layout.funct3(j), F::ONE)]; - for i in 0..3 { - terms.push((layout.instr_bit(12 + i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - { - // rs1_field = bits[15..19] - let mut terms = vec![(layout.rs1_field(j), F::ONE)]; - for i in 0..5 { - terms.push((layout.instr_bit(15 + i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - { - // rs2_field = bits[20..24] - let mut terms = vec![(layout.rs2_field(j), F::ONE)]; - for i in 0..5 { - terms.push((layout.instr_bit(20 + i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - { - // funct7 = bits[25..31] - let mut terms = vec![(layout.funct7(j), F::ONE)]; - for i in 0..7 { - terms.push((layout.instr_bit(25 + i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // imm12_raw = bits[20..31] (unsigned 12-bit) - { - let mut terms = vec![(layout.imm12_raw(j), F::ONE)]; - for i in 0..12 { - terms.push((layout.instr_bit(20 + i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // imm_i (u32 representation): imm12_raw + sign*(2^32 - 2^12) - { - let sign = layout.instr_bit(31, j); - let bias = (1u64 << 32) - (1u64 << 12); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.imm_i(j), F::ONE), - (layout.imm12_raw(j), -F::ONE), - (sign, 
-F::from_u64(bias)), - ], - )); - } - - // imm_s (u32 representation): - // low5 = bits[7..11] (already packed as rd_field) - // high7 = bits[25..31] at positions [5..11] - // imm_s = low5 + Σ 2^(5+i)*bits[25+i] + sign*(2^32 - 2^12) - { - let sign = layout.instr_bit(31, j); - let bias = (1u64 << 32) - (1u64 << 12); - let mut terms = vec![ - (layout.imm_s(j), F::ONE), - (layout.rd_field(j), -F::ONE), - (sign, -F::from_u64(bias)), - ]; - for i in 0..7 { - terms.push((layout.instr_bit(25 + i, j), -F::from_u64(pow2_u64(5 + i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // imm_u (already << 12): Σ_{i=12..31} 2^i * bit[i] - { - let mut terms = vec![(layout.imm_u(j), F::ONE)]; - for i in 12..32 { - terms.push((layout.instr_bit(i, j), -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // imm_b_raw (unsigned 13-bit, bit 0 is 0): - // imm[12] = bit31 - // imm[11] = bit7 - // imm[10:5] = bits[30:25] - // imm[4:1] = bits[11:8] - { - let mut terms = vec![(layout.imm_b_raw(j), F::ONE)]; - terms.push((layout.instr_bit(31, j), -F::from_u64(pow2_u64(12)))); - terms.push((layout.instr_bit(7, j), -F::from_u64(pow2_u64(11)))); - for i in 0..6 { - terms.push((layout.instr_bit(25 + i, j), -F::from_u64(pow2_u64(5 + i)))); - } - for i in 0..4 { - terms.push((layout.instr_bit(8 + i, j), -F::from_u64(pow2_u64(1 + i)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // imm_b (signed i32, as field element): imm_b = imm_b_raw - sign*2^13. 
- { - let sign = layout.instr_bit(31, j); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.imm_b(j), F::ONE), - (layout.imm_b_raw(j), -F::ONE), - (sign, F::from_u64(pow2_u64(13))), - ], - )); - } - - // imm_j_raw (unsigned 21-bit, bit 0 is 0): - // imm[20] = bit31 - // imm[19:12] = bits[19:12] - // imm[11] = bit20 - // imm[10:1] = bits[30:21] - { - let mut terms = vec![(layout.imm_j_raw(j), F::ONE)]; - terms.push((layout.instr_bit(31, j), -F::from_u64(pow2_u64(20)))); - for i in 12..20 { - terms.push((layout.instr_bit(i, j), -F::from_u64(pow2_u64(i)))); - } - terms.push((layout.instr_bit(20, j), -F::from_u64(pow2_u64(11)))); - for i in 21..31 { - terms.push((layout.instr_bit(i, j), -F::from_u64(pow2_u64(i - 20)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // imm_j (signed i32, as field element): imm_j = imm_j_raw - sign*2^21. - { - let sign = layout.instr_bit(31, j); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.imm_j(j), F::ONE), - (layout.imm_j_raw(j), -F::ONE), - (sign, F::from_u64(pow2_u64(21))), - ], - )); - } - - // Flags: boolean + one-hot. 
- let flags = [ - layout.is_add(j), - layout.is_sub(j), - layout.is_sll(j), - layout.is_slt(j), - layout.is_sltu(j), - layout.is_xor(j), - layout.is_srl(j), - layout.is_sra(j), - layout.is_or(j), - layout.is_and(j), - layout.is_mul(j), - layout.is_mulh(j), - layout.is_mulhu(j), - layout.is_mulhsu(j), - layout.is_div(j), - layout.is_divu(j), - layout.is_rem(j), - layout.is_remu(j), - layout.is_addi(j), - layout.is_slti(j), - layout.is_sltiu(j), - layout.is_xori(j), - layout.is_ori(j), - layout.is_andi(j), - layout.is_slli(j), - layout.is_srli(j), - layout.is_srai(j), - layout.is_lb(j), - layout.is_lbu(j), - layout.is_lh(j), - layout.is_lhu(j), - layout.is_lw(j), - layout.is_sb(j), - layout.is_sh(j), - layout.is_sw(j), - layout.is_amoswap_w(j), - layout.is_amoadd_w(j), - layout.is_amoxor_w(j), - layout.is_amoor_w(j), - layout.is_amoand_w(j), - layout.is_lui(j), - layout.is_auipc(j), - layout.is_beq(j), - layout.is_bne(j), - layout.is_blt(j), - layout.is_bge(j), - layout.is_bltu(j), - layout.is_bgeu(j), - layout.is_jal(j), - layout.is_jalr(j), - layout.is_fence(j), - layout.is_halt(j), - ]; - for &f in &flags { - constraints.push(Constraint::terms(f, false, vec![(f, F::ONE), (is_active, -F::ONE)])); - } - { - let mut terms = Vec::with_capacity(flags.len() + 1); - for &f in &flags { - terms.push((f, F::ONE)); - } - terms.push((is_active, -F::ONE)); - constraints.push(Constraint::terms(one, false, terms)); - } - - // Decode constraints for the supported RV32I/M core subset. 
- constraints.push(Constraint::eq_const(layout.is_add(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::zero(layout.is_add(j), layout.funct3(j))); - constraints.push(Constraint::zero(layout.is_add(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_sub(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::zero(layout.is_sub(j), layout.funct3(j))); - constraints.push(Constraint::eq_const(layout.is_sub(j), one, layout.funct7(j), 0x20)); - - constraints.push(Constraint::eq_const(layout.is_sll(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_sll(j), one, layout.funct3(j), 0x1)); - constraints.push(Constraint::zero(layout.is_sll(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_slt(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_slt(j), one, layout.funct3(j), 0x2)); - constraints.push(Constraint::zero(layout.is_slt(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_sltu(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_sltu(j), one, layout.funct3(j), 0x3)); - constraints.push(Constraint::zero(layout.is_sltu(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_xor(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_xor(j), one, layout.funct3(j), 0x4)); - constraints.push(Constraint::zero(layout.is_xor(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_srl(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_srl(j), one, layout.funct3(j), 0x5)); - constraints.push(Constraint::zero(layout.is_srl(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_sra(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_sra(j), one, layout.funct3(j), 0x5)); - constraints.push(Constraint::eq_const(layout.is_sra(j), 
one, layout.funct7(j), 0x20)); - - constraints.push(Constraint::eq_const(layout.is_or(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_or(j), one, layout.funct3(j), 0x6)); - constraints.push(Constraint::zero(layout.is_or(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_and(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_and(j), one, layout.funct3(j), 0x7)); - constraints.push(Constraint::zero(layout.is_and(j), layout.funct7(j))); - - // RV32M (funct7 = 0b0000001). - constraints.push(Constraint::eq_const(layout.is_mul(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::zero(layout.is_mul(j), layout.funct3(j))); - constraints.push(Constraint::eq_const(layout.is_mul(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_mulh(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_mulh(j), one, layout.funct3(j), 0x1)); - constraints.push(Constraint::eq_const(layout.is_mulh(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_mulhsu(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_mulhsu(j), one, layout.funct3(j), 0x2)); - constraints.push(Constraint::eq_const(layout.is_mulhsu(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_mulhu(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_mulhu(j), one, layout.funct3(j), 0x3)); - constraints.push(Constraint::eq_const(layout.is_mulhu(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_div(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_div(j), one, layout.funct3(j), 0x4)); - constraints.push(Constraint::eq_const(layout.is_div(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_divu(j), one, layout.opcode(j), 0x33)); - 
constraints.push(Constraint::eq_const(layout.is_divu(j), one, layout.funct3(j), 0x5)); - constraints.push(Constraint::eq_const(layout.is_divu(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_rem(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_rem(j), one, layout.funct3(j), 0x6)); - constraints.push(Constraint::eq_const(layout.is_rem(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_remu(j), one, layout.opcode(j), 0x33)); - constraints.push(Constraint::eq_const(layout.is_remu(j), one, layout.funct3(j), 0x7)); - constraints.push(Constraint::eq_const(layout.is_remu(j), one, layout.funct7(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_addi(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::zero(layout.is_addi(j), layout.funct3(j))); - - constraints.push(Constraint::eq_const(layout.is_slti(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_slti(j), one, layout.funct3(j), 0x2)); - - constraints.push(Constraint::eq_const(layout.is_sltiu(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_sltiu(j), one, layout.funct3(j), 0x3)); - - constraints.push(Constraint::eq_const(layout.is_xori(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_xori(j), one, layout.funct3(j), 0x4)); - - constraints.push(Constraint::eq_const(layout.is_ori(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_ori(j), one, layout.funct3(j), 0x6)); - - constraints.push(Constraint::eq_const(layout.is_andi(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_andi(j), one, layout.funct3(j), 0x7)); - - constraints.push(Constraint::eq_const(layout.is_slli(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_slli(j), one, layout.funct3(j), 0x1)); - 
constraints.push(Constraint::zero(layout.is_slli(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_srli(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_srli(j), one, layout.funct3(j), 0x5)); - constraints.push(Constraint::zero(layout.is_srli(j), layout.funct7(j))); - - constraints.push(Constraint::eq_const(layout.is_srai(j), one, layout.opcode(j), 0x13)); - constraints.push(Constraint::eq_const(layout.is_srai(j), one, layout.funct3(j), 0x5)); - constraints.push(Constraint::eq_const(layout.is_srai(j), one, layout.funct7(j), 0x20)); - - constraints.push(Constraint::eq_const(layout.is_lb(j), one, layout.opcode(j), 0x03)); - constraints.push(Constraint::eq_const(layout.is_lb(j), one, layout.funct3(j), 0x0)); - - constraints.push(Constraint::eq_const(layout.is_lh(j), one, layout.opcode(j), 0x03)); - constraints.push(Constraint::eq_const(layout.is_lh(j), one, layout.funct3(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_lw(j), one, layout.opcode(j), 0x03)); - constraints.push(Constraint::eq_const(layout.is_lw(j), one, layout.funct3(j), 0x2)); - - constraints.push(Constraint::eq_const(layout.is_lbu(j), one, layout.opcode(j), 0x03)); - constraints.push(Constraint::eq_const(layout.is_lbu(j), one, layout.funct3(j), 0x4)); - - constraints.push(Constraint::eq_const(layout.is_lhu(j), one, layout.opcode(j), 0x03)); - constraints.push(Constraint::eq_const(layout.is_lhu(j), one, layout.funct3(j), 0x5)); - - constraints.push(Constraint::eq_const(layout.is_sb(j), one, layout.opcode(j), 0x23)); - constraints.push(Constraint::eq_const(layout.is_sb(j), one, layout.funct3(j), 0x0)); - - constraints.push(Constraint::eq_const(layout.is_sh(j), one, layout.opcode(j), 0x23)); - constraints.push(Constraint::eq_const(layout.is_sh(j), one, layout.funct3(j), 0x1)); - - constraints.push(Constraint::eq_const(layout.is_sw(j), one, layout.opcode(j), 0x23)); - constraints.push(Constraint::eq_const(layout.is_sw(j), one, 
layout.funct3(j), 0x2)); - - // RV32A atomics (AMO*, word only): opcode=0x2F, funct3=010, funct5 in bits [31:27]. - constraints.push(Constraint::eq_const( - layout.is_amoswap_w(j), - one, - layout.opcode(j), - 0x2f, - )); - constraints.push(Constraint::eq_const(layout.is_amoswap_w(j), one, layout.funct3(j), 0x2)); - constraints.push(Constraint::terms( - layout.is_amoswap_w(j), - false, - vec![(layout.instr_bit(27, j), F::ONE), (one, -F::ONE)], - )); - constraints.push(Constraint::zero(layout.is_amoswap_w(j), layout.instr_bit(28, j))); - constraints.push(Constraint::zero(layout.is_amoswap_w(j), layout.instr_bit(29, j))); - constraints.push(Constraint::zero(layout.is_amoswap_w(j), layout.instr_bit(30, j))); - constraints.push(Constraint::zero(layout.is_amoswap_w(j), layout.instr_bit(31, j))); - - constraints.push(Constraint::eq_const(layout.is_amoadd_w(j), one, layout.opcode(j), 0x2f)); - constraints.push(Constraint::eq_const(layout.is_amoadd_w(j), one, layout.funct3(j), 0x2)); - constraints.push(Constraint::zero(layout.is_amoadd_w(j), layout.instr_bit(27, j))); - constraints.push(Constraint::zero(layout.is_amoadd_w(j), layout.instr_bit(28, j))); - constraints.push(Constraint::zero(layout.is_amoadd_w(j), layout.instr_bit(29, j))); - constraints.push(Constraint::zero(layout.is_amoadd_w(j), layout.instr_bit(30, j))); - constraints.push(Constraint::zero(layout.is_amoadd_w(j), layout.instr_bit(31, j))); - - constraints.push(Constraint::eq_const(layout.is_amoxor_w(j), one, layout.opcode(j), 0x2f)); - constraints.push(Constraint::eq_const(layout.is_amoxor_w(j), one, layout.funct3(j), 0x2)); - constraints.push(Constraint::zero(layout.is_amoxor_w(j), layout.instr_bit(27, j))); - constraints.push(Constraint::zero(layout.is_amoxor_w(j), layout.instr_bit(28, j))); - constraints.push(Constraint::terms( - layout.is_amoxor_w(j), - false, - vec![(layout.instr_bit(29, j), F::ONE), (one, -F::ONE)], - )); - constraints.push(Constraint::zero(layout.is_amoxor_w(j), layout.instr_bit(30, 
j))); - constraints.push(Constraint::zero(layout.is_amoxor_w(j), layout.instr_bit(31, j))); - - constraints.push(Constraint::eq_const(layout.is_amoor_w(j), one, layout.opcode(j), 0x2f)); - constraints.push(Constraint::eq_const(layout.is_amoor_w(j), one, layout.funct3(j), 0x2)); - constraints.push(Constraint::zero(layout.is_amoor_w(j), layout.instr_bit(27, j))); - constraints.push(Constraint::zero(layout.is_amoor_w(j), layout.instr_bit(28, j))); - constraints.push(Constraint::zero(layout.is_amoor_w(j), layout.instr_bit(29, j))); - constraints.push(Constraint::terms( - layout.is_amoor_w(j), - false, - vec![(layout.instr_bit(30, j), F::ONE), (one, -F::ONE)], - )); - constraints.push(Constraint::zero(layout.is_amoor_w(j), layout.instr_bit(31, j))); - - constraints.push(Constraint::eq_const(layout.is_amoand_w(j), one, layout.opcode(j), 0x2f)); - constraints.push(Constraint::eq_const(layout.is_amoand_w(j), one, layout.funct3(j), 0x2)); - constraints.push(Constraint::zero(layout.is_amoand_w(j), layout.instr_bit(27, j))); - constraints.push(Constraint::zero(layout.is_amoand_w(j), layout.instr_bit(28, j))); - constraints.push(Constraint::terms( - layout.is_amoand_w(j), - false, - vec![(layout.instr_bit(29, j), F::ONE), (one, -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.is_amoand_w(j), - false, - vec![(layout.instr_bit(30, j), F::ONE), (one, -F::ONE)], - )); - constraints.push(Constraint::zero(layout.is_amoand_w(j), layout.instr_bit(31, j))); - - constraints.push(Constraint::eq_const(layout.is_lui(j), one, layout.opcode(j), 0x37)); - constraints.push(Constraint::eq_const(layout.is_auipc(j), one, layout.opcode(j), 0x17)); - - constraints.push(Constraint::eq_const(layout.is_beq(j), one, layout.opcode(j), 0x63)); - constraints.push(Constraint::zero(layout.is_beq(j), layout.funct3(j))); - - constraints.push(Constraint::eq_const(layout.is_bne(j), one, layout.opcode(j), 0x63)); - constraints.push(Constraint::eq_const(layout.is_bne(j), one, layout.funct3(j), 
0x1)); - - constraints.push(Constraint::eq_const(layout.is_blt(j), one, layout.opcode(j), 0x63)); - constraints.push(Constraint::eq_const(layout.is_blt(j), one, layout.funct3(j), 0x4)); - - constraints.push(Constraint::eq_const(layout.is_bge(j), one, layout.opcode(j), 0x63)); - constraints.push(Constraint::eq_const(layout.is_bge(j), one, layout.funct3(j), 0x5)); - - constraints.push(Constraint::eq_const(layout.is_bltu(j), one, layout.opcode(j), 0x63)); - constraints.push(Constraint::eq_const(layout.is_bltu(j), one, layout.funct3(j), 0x6)); - - constraints.push(Constraint::eq_const(layout.is_bgeu(j), one, layout.opcode(j), 0x63)); - constraints.push(Constraint::eq_const(layout.is_bgeu(j), one, layout.funct3(j), 0x7)); - - constraints.push(Constraint::eq_const(layout.is_jal(j), one, layout.opcode(j), 0x6f)); - - constraints.push(Constraint::eq_const(layout.is_jalr(j), one, layout.opcode(j), 0x67)); - constraints.push(Constraint::zero(layout.is_jalr(j), layout.funct3(j))); - - constraints.push(Constraint::eq_const(layout.is_fence(j), one, layout.opcode(j), 0x0f)); - constraints.push(Constraint::zero(layout.is_fence(j), layout.funct3(j))); - - constraints.push(Constraint::eq_const(layout.is_halt(j), one, layout.opcode(j), 0x73)); - constraints.push(Constraint::zero(layout.is_halt(j), layout.imm12_raw(j))); - constraints.push(Constraint::zero(layout.is_halt(j), layout.rd_field(j))); - constraints.push(Constraint::zero(layout.is_halt(j), layout.rs1_field(j))); - constraints.push(Constraint::zero(layout.is_halt(j), layout.funct3(j))); - - // Selector one-hots (rs1/rs2 always derived from rs1_field/rs2_field). 
- for r in 0..32 { - let b1 = layout.rs1_sel(r, j); - let b2 = layout.rs2_sel(r, j); - let bd = layout.rd_sel(r, j); - constraints.push(Constraint::terms(b1, false, vec![(b1, F::ONE), (is_active, -F::ONE)])); - constraints.push(Constraint::terms(b2, false, vec![(b2, F::ONE), (is_active, -F::ONE)])); - constraints.push(Constraint::terms(bd, false, vec![(bd, F::ONE), (is_active, -F::ONE)])); - } - for sels in ["rs1", "rs2", "rd"] { - let mut terms = Vec::with_capacity(33); - for r in 0..32 { - let col = match sels { - "rs1" => layout.rs1_sel(r, j), - "rs2" => layout.rs2_sel(r, j), - _ => layout.rd_sel(r, j), - }; - terms.push((col, F::ONE)); - } - terms.push((is_active, -F::ONE)); - constraints.push(Constraint::terms(one, false, terms)); - } - - // rs1_field == Σ r * rs1_sel[r] - { - let mut terms = vec![(layout.rs1_field(j), F::ONE)]; - for r in 0..32 { - terms.push((layout.rs1_sel(r, j), -F::from_u64(r as u64))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - // rs2_field == Σ r * rs2_sel[r] - { - let mut terms = vec![(layout.rs2_field(j), F::ONE)]; - for r in 0..32 { - terms.push((layout.rs2_sel(r, j), -F::from_u64(r as u64))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // rd_field == Σ r * rd_sel[r] when instruction writes rd. 
- let writes_rd_flags = [ - layout.is_add(j), - layout.is_sub(j), - layout.is_sll(j), - layout.is_slt(j), - layout.is_sltu(j), - layout.is_xor(j), - layout.is_srl(j), - layout.is_sra(j), - layout.is_or(j), - layout.is_and(j), - layout.is_mul(j), - layout.is_mulh(j), - layout.is_mulhu(j), - layout.is_mulhsu(j), - layout.is_div(j), - layout.is_divu(j), - layout.is_rem(j), - layout.is_remu(j), - layout.is_addi(j), - layout.is_slti(j), - layout.is_sltiu(j), - layout.is_xori(j), - layout.is_ori(j), - layout.is_andi(j), - layout.is_slli(j), - layout.is_srli(j), - layout.is_srai(j), - layout.is_lb(j), - layout.is_lbu(j), - layout.is_lh(j), - layout.is_lhu(j), - layout.is_lw(j), - layout.is_amoswap_w(j), - layout.is_amoadd_w(j), - layout.is_amoxor_w(j), - layout.is_amoor_w(j), - layout.is_amoand_w(j), - layout.is_lui(j), - layout.is_auipc(j), - layout.is_jal(j), - layout.is_jalr(j), - ]; - { - let mut terms = vec![(layout.rd_field(j), F::ONE)]; - for r in 0..32 { - terms.push((layout.rd_sel(r, j), -F::from_u64(r as u64))); - } - constraints.push(Constraint::terms_or(&writes_rd_flags, false, terms)); - } - - // If NOT writing rd, force rd_sel[1..] = 0 (so rd_sel == x0). - for r in 1..32 { - constraints.push(Constraint::terms_or( - &writes_rd_flags, - true, // (1 - writes_rd) - vec![(layout.rd_sel(r, j), F::ONE)], - )); - } - - // Bind rs1_val / rs2_val to regs_in via one-hot selectors. - for r in 0..32 { - constraints.push(Constraint::terms( - layout.rs1_sel(r, j), - false, - vec![(layout.rs1_val(j), F::ONE), (layout.reg_in(r, j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.rs2_sel(r, j), - false, - vec![(layout.rs2_val(j), F::ONE), (layout.reg_in(r, j), -F::ONE)], - )); - } - - // ECALL helpers (Jolt marker/print IDs). 
- let a0 = layout.reg_in(10, j); - let ecall_is_cycle = layout.ecall_is_cycle(j); - let ecall_is_print = layout.ecall_is_print(j); - let ecall_halts = layout.ecall_halts(j); - let halt_effective = layout.halt_effective(j); - - // Decompose a0 into bits. - for bit in 0..32 { - let b = layout.ecall_a0_bit(bit, j); - constraints.push(Constraint { - condition_col: b, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (b, -F::ONE)], // 1 - b - c_terms: Vec::new(), - }); - } - { - let mut terms = vec![(a0, F::ONE)]; - for bit in 0..32 { - terms.push((layout.ecall_a0_bit(bit, j), -F::from_u64(pow2_u64(bit)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // ecall_is_cycle = (a0 == JOLT_CYCLE_TRACK_ECALL_NUM). - let cycle_const = JOLT_CYCLE_TRACK_ECALL_NUM as u32; - { - let bit0 = layout.ecall_a0_bit(0, j); - let prefix0 = layout.ecall_cycle_prefix(0, j); - if (cycle_const & 1) == 1 { - constraints.push(Constraint::terms(one, false, vec![(prefix0, F::ONE), (bit0, -F::ONE)])); - } else { - constraints.push(Constraint::terms( - one, - false, - vec![(prefix0, F::ONE), (bit0, F::ONE), (one, -F::ONE)], - )); - } - } - for k in 1..31 { - let prev = layout.ecall_cycle_prefix(k - 1, j); - let next = layout.ecall_cycle_prefix(k, j); - let bit_col = layout.ecall_a0_bit(k, j); - let bit_is_one = ((cycle_const >> k) & 1) == 1; - let b_terms = if bit_is_one { - vec![(bit_col, F::ONE)] - } else { - vec![(one, F::ONE), (bit_col, -F::ONE)] - }; - constraints.push(Constraint { - condition_col: prev, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms, - c_terms: vec![(next, F::ONE)], - }); - } - { - let prev = layout.ecall_cycle_prefix(30, j); - let bit_col = layout.ecall_a0_bit(31, j); - let bit_is_one = ((cycle_const >> 31) & 1) == 1; - let b_terms = if bit_is_one { - vec![(bit_col, F::ONE)] - } else { - vec![(one, F::ONE), (bit_col, -F::ONE)] - }; - constraints.push(Constraint { - 
condition_col: prev, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms, - c_terms: vec![(ecall_is_cycle, F::ONE)], - }); - } - constraints.push(Constraint::mul(ecall_is_cycle, ecall_is_cycle, ecall_is_cycle)); - - // ecall_is_print = (a0 == JOLT_PRINT_ECALL_NUM). - let print_const = JOLT_PRINT_ECALL_NUM as u32; - { - let bit0 = layout.ecall_a0_bit(0, j); - let prefix0 = layout.ecall_print_prefix(0, j); - if (print_const & 1) == 1 { - constraints.push(Constraint::terms(one, false, vec![(prefix0, F::ONE), (bit0, -F::ONE)])); - } else { - constraints.push(Constraint::terms( - one, - false, - vec![(prefix0, F::ONE), (bit0, F::ONE), (one, -F::ONE)], - )); - } - } - for k in 1..31 { - let prev = layout.ecall_print_prefix(k - 1, j); - let next = layout.ecall_print_prefix(k, j); - let bit_col = layout.ecall_a0_bit(k, j); - let bit_is_one = ((print_const >> k) & 1) == 1; - let b_terms = if bit_is_one { - vec![(bit_col, F::ONE)] - } else { - vec![(one, F::ONE), (bit_col, -F::ONE)] - }; - constraints.push(Constraint { - condition_col: prev, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms, - c_terms: vec![(next, F::ONE)], - }); - } - { - let prev = layout.ecall_print_prefix(30, j); - let bit_col = layout.ecall_a0_bit(31, j); - let bit_is_one = ((print_const >> 31) & 1) == 1; - let b_terms = if bit_is_one { - vec![(bit_col, F::ONE)] - } else { - vec![(one, F::ONE), (bit_col, -F::ONE)] - }; - constraints.push(Constraint { - condition_col: prev, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms, - c_terms: vec![(ecall_is_print, F::ONE)], - }); - } - constraints.push(Constraint::mul(ecall_is_print, ecall_is_print, ecall_is_print)); - - constraints.push(Constraint::terms( - one, - false, - vec![ - (ecall_halts, F::ONE), - (ecall_is_cycle, F::ONE), - (ecall_is_print, F::ONE), - (one, -F::ONE), - ], - )); - constraints.push(Constraint::mul(ecall_halts, ecall_halts, ecall_halts)); - 
constraints.push(Constraint::mul(layout.is_halt(j), ecall_halts, halt_effective)); - - // -------------------------------------------------------------------- - // RV32M helpers (in-circuit MUL/DIV/REM) - // -------------------------------------------------------------------- - - // Range-check rd_write_val to 32 bits (keeps writeback canonical). - enforce_u32_bits( - &mut constraints, - one, - layout.rd_write_val(j), - layout.rd_write_bits_start, - layout.chunk_size, - j, - ); - - // Range-check mem_rv to 32 bits so byte/half extraction is sound. - enforce_u32_bits( - &mut constraints, - one, - layout.mem_rv(j), - layout.mem_rv_bits_start, - layout.chunk_size, - j, - ); - - // Range-check mul_lo / mul_hi to ensure the decomposition is unique. - enforce_u32_bits( - &mut constraints, - one, - layout.mul_lo(j), - layout.mul_lo_bits_start, - layout.chunk_size, - j, - ); - enforce_u32_bits( - &mut constraints, - one, - layout.mul_hi(j), - layout.mul_hi_bits_start, - layout.chunk_size, - j, - ); - - // Disambiguate the MUL decomposition in Goldilocks by ruling out `mul_hi == 0xffff_ffff`. - // - // For 32-bit operands, the true 64-bit product has `mul_hi <= 0xffff_fffe`. Without this, - // the field equation `rs1*rs2 = mul_lo + 2^32*mul_hi (mod p)` also admits the solution - // `mul_lo + 2^32*mul_hi = rs1*rs2 + p` when `rs1*rs2 <= 2^32-2`, where `p = 2^64-2^32+1`. - // - // We enforce `mul_hi != 0xffff_ffff` by constraining `∏_{i=0..31} mul_hi_bit[i] = 0`. 
- constraints.push(Constraint::terms( - one, - false, - vec![(layout.mul_hi_prefix(0, j), F::ONE), (layout.mul_hi_bit(0, j), -F::ONE)], - )); - for k in 1..31 { - constraints.push(Constraint::mul( - layout.mul_hi_prefix(k - 1, j), - layout.mul_hi_bit(k, j), - layout.mul_hi_prefix(k, j), - )); - } - constraints.push(Constraint { - condition_col: layout.mul_hi_prefix(30, j), - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(layout.mul_hi_bit(31, j), F::ONE)], - c_terms: Vec::new(), - }); - - // mul_carry bits (0..3, but only 0..2 will satisfy the MULH equations). - for bit in 0..2 { - let b = layout.mul_carry_bit(bit, j); - constraints.push(Constraint { - condition_col: b, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (b, -F::ONE)], // 1 - b - c_terms: Vec::new(), - }); - } - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.mul_carry(j), F::ONE), - (layout.mul_carry_bit(0, j), -F::ONE), - (layout.mul_carry_bit(1, j), -F::from_u64(2)), - ], - )); - - // MUL decomposition (always enforced): rs1_val * rs2_val = mul_lo + 2^32 * mul_hi. - constraints.push(Constraint { - condition_col: layout.rs1_val(j), - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(layout.rs2_val(j), F::ONE)], - c_terms: vec![ - (layout.mul_lo(j), F::ONE), - (layout.mul_hi(j), F::from_u64(pow2_u64(32))), - ], - }); - - // MUL/MULHU writeback. 
- constraints.push(Constraint::terms( - layout.is_mul(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.mul_lo(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.is_mulhu(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.mul_hi(j), -F::ONE)], - )); - - // rs2_bit[i] ∈ {0,1} - for bit in 0..32 { - let b = layout.rs2_bit(bit, j); - constraints.push(Constraint { - condition_col: b, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (b, -F::ONE)], // 1 - b - c_terms: Vec::new(), - }); - } - // rs1_bit[i] ∈ {0,1} - for bit in 0..32 { - let b = layout.rs1_bit(bit, j); - constraints.push(Constraint { - condition_col: b, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (b, -F::ONE)], // 1 - b - c_terms: Vec::new(), - }); - } - - // rs2_val = Σ 2^i * rs2_bit[i] - { - let mut terms = vec![(layout.rs2_val(j), F::ONE)]; - for bit in 0..32 { - terms.push((layout.rs2_bit(bit, j), -F::from_u64(pow2_u64(bit)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - // rs1_val = Σ 2^i * rs1_bit[i] - { - let mut terms = vec![(layout.rs1_val(j), F::ONE)]; - for bit in 0..32 { - terms.push((layout.rs1_bit(bit, j), -F::from_u64(pow2_u64(bit)))); - } - constraints.push(Constraint::terms(one, false, terms)); - } - - // Prefix product chain for Π_{i=0..31} (1 - rs2_bit[i]). 
- // prefix[0] = (1 - b0) - constraints.push(Constraint { - condition_col: one, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (layout.rs2_bit(0, j), -F::ONE)], - c_terms: vec![(layout.rs2_zero_prefix(0, j), F::ONE)], - }); - // prefix[k] = prefix[k-1] * (1 - b_k) for k=1..30 - for k in 1..31 { - constraints.push(Constraint { - condition_col: layout.rs2_zero_prefix(k - 1, j), - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (layout.rs2_bit(k, j), -F::ONE)], - c_terms: vec![(layout.rs2_zero_prefix(k, j), F::ONE)], - }); - } - // rs2_is_zero = prefix[30] * (1 - b_31) - constraints.push(Constraint { - condition_col: layout.rs2_zero_prefix(30, j), - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (layout.rs2_bit(31, j), -F::ONE)], - c_terms: vec![(layout.rs2_is_zero(j), F::ONE)], - }); - - // rs2_nonzero = 1 - rs2_is_zero. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.rs2_nonzero(j), F::ONE), - (layout.rs2_is_zero(j), F::ONE), - (one, -F::ONE), - ], - )); - - let rs1_sign = layout.rs1_bit(31, j); - let rs2_sign = layout.rs2_bit(31, j); - - // rs1_abs / rs2_abs from two's-complement sign bits. - constraints.push(Constraint::terms( - rs1_sign, - true, - vec![(layout.rs1_abs(j), F::ONE), (layout.rs1_val(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - rs1_sign, - false, - vec![ - (layout.rs1_abs(j), F::ONE), - (layout.rs1_val(j), F::ONE), - (one, -F::from_u64(pow2_u64(32))), - ], - )); - constraints.push(Constraint::terms( - rs2_sign, - true, - vec![(layout.rs2_abs(j), F::ONE), (layout.rs2_val(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - rs2_sign, - false, - vec![ - (layout.rs2_abs(j), F::ONE), - (layout.rs2_val(j), F::ONE), - (one, -F::from_u64(pow2_u64(32))), - ], - )); - - // Sign helpers. 
- constraints.push(Constraint::mul(rs1_sign, rs2_sign, layout.rs1_rs2_sign_and(j))); - constraints.push(Constraint::mul(rs1_sign, layout.rs2_val(j), layout.rs1_sign_rs2_val(j))); - constraints.push(Constraint::mul(rs2_sign, layout.rs1_val(j), layout.rs2_sign_rs1_val(j))); - - // MULH/MULHSU writeback with signed correction. - constraints.push(Constraint::terms( - layout.is_mulh(j), - false, - vec![ - (layout.rd_write_val(j), F::ONE), - (layout.mul_carry(j), F::from_u64(pow2_u64(32))), - (layout.mul_hi(j), -F::ONE), - (layout.rs1_sign_rs2_val(j), F::ONE), - (layout.rs2_sign_rs1_val(j), F::ONE), - (layout.rs1_rs2_sign_and(j), -F::from_u64(pow2_u64(32))), - (one, -F::from_u64(pow2_u64(32))), - ], - )); - constraints.push(Constraint::terms( - layout.is_mulhsu(j), - false, - vec![ - (layout.rd_write_val(j), F::ONE), - (layout.mul_carry(j), F::from_u64(pow2_u64(32))), - (layout.mul_hi(j), -F::ONE), - (layout.rs1_sign_rs2_val(j), F::ONE), - (one, -F::from_u64(pow2_u64(32))), - ], - )); - - // is_divu_or_remu = is_divu + is_remu. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.is_divu_or_remu(j), F::ONE), - (layout.is_divu(j), -F::ONE), - (layout.is_remu(j), -F::ONE), - ], - )); - - // is_div_or_rem = is_div + is_rem. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.is_div_or_rem(j), F::ONE), - (layout.is_div(j), -F::ONE), - (layout.is_rem(j), -F::ONE), - ], - )); - - // div_rem_check (unsigned) = is_divu_or_remu * rs2_nonzero. - constraints.push(Constraint::mul( - layout.is_divu_or_remu(j), - layout.rs2_nonzero(j), - layout.div_rem_check(j), - )); - - // div_rem_check_signed = is_div_or_rem * rs2_nonzero. - constraints.push(Constraint::mul( - layout.is_div_or_rem(j), - layout.rs2_nonzero(j), - layout.div_rem_check_signed(j), - )); - - // divu_by_zero = is_divu * rs2_is_zero. 
- constraints.push(Constraint::mul( - layout.is_divu(j), - layout.rs2_is_zero(j), - layout.divu_by_zero(j), - )); - - // div_by_zero / div_nonzero for signed DIV. - constraints.push(Constraint::mul( - layout.is_div(j), - layout.rs2_is_zero(j), - layout.div_by_zero(j), - )); - constraints.push(Constraint::mul( - layout.is_div(j), - layout.rs2_nonzero(j), - layout.div_nonzero(j), - )); - - // rem_nonzero / rem_by_zero for signed REM. - constraints.push(Constraint::mul( - layout.is_rem(j), - layout.rs2_nonzero(j), - layout.rem_nonzero(j), - )); - constraints.push(Constraint::mul( - layout.is_rem(j), - layout.rs2_is_zero(j), - layout.rem_by_zero(j), - )); - - // DIVU by zero: quotient must be all 1s. - constraints.push(Constraint::terms( - layout.divu_by_zero(j), - false, - vec![(layout.div_quot(j), F::ONE), (one, -F::from_u64(u32::MAX as u64))], - )); - - // div_divisor selects rs2_val (unsigned) or rs2_abs (signed). - constraints.push(Constraint::terms( - layout.is_divu_or_remu(j), - false, - vec![(layout.div_divisor(j), F::ONE), (layout.rs2_val(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.is_div_or_rem(j), - false, - vec![(layout.div_divisor(j), F::ONE), (layout.rs2_abs(j), -F::ONE)], - )); - - // div_prod = div_divisor * div_quot (always computed). - constraints.push(Constraint::mul( - layout.div_divisor(j), - layout.div_quot(j), - layout.div_prod(j), - )); - - // Unsigned: dividend = divisor * quotient + remainder. - constraints.push(Constraint::terms( - layout.is_divu_or_remu(j), - false, - vec![ - (layout.rs1_val(j), F::ONE), - (layout.div_prod(j), -F::ONE), - (layout.div_rem(j), -F::ONE), - ], - )); - - // Signed: |dividend| = |divisor| * quotient + remainder (divisor != 0). 
- constraints.push(Constraint::terms( - layout.div_rem_check_signed(j), - false, - vec![ - (layout.rs1_abs(j), F::ONE), - (layout.div_prod(j), -F::ONE), - (layout.div_rem(j), -F::ONE), - ], - )); - - // Range-check division outputs to keep quotient/remainder canonical. - enforce_u32_bits( - &mut constraints, - one, - layout.div_quot(j), - layout.div_quot_bits_start, - layout.chunk_size, - j, - ); - enforce_u32_bits( - &mut constraints, - one, - layout.div_rem(j), - layout.div_rem_bits_start, - layout.chunk_size, - j, - ); - - // div_sign = rs1_sign XOR rs2_sign. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.div_sign(j), F::ONE), - (rs1_sign, -F::ONE), - (rs2_sign, -F::ONE), - (layout.rs1_rs2_sign_and(j), F::from_u64(2)), - ], - )); - // div_sign boolean. - constraints.push(Constraint::terms( - layout.div_sign(j), - false, - vec![(layout.div_sign(j), F::ONE), (one, -F::ONE)], - )); - - // div_quot_carry / div_rem_carry bits (used to normalize negative zero). - for &carry in &[layout.div_quot_carry(j), layout.div_rem_carry(j)] { - constraints.push(Constraint { - condition_col: carry, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(one, F::ONE), (carry, -F::ONE)], // 1 - carry - c_terms: Vec::new(), - }); - } - // If sign=0, carry must be 0. - constraints.push(Constraint::terms( - layout.div_sign(j), - true, - vec![(layout.div_quot_carry(j), F::ONE)], - )); - constraints.push(Constraint::terms( - rs1_sign, - true, - vec![(layout.div_rem_carry(j), F::ONE)], - )); - - // Signed quotient / remainder (two's complement, with carry to allow -0 -> 0). 
- constraints.push(Constraint::terms( - layout.div_sign(j), - true, - vec![(layout.div_quot_signed(j), F::ONE), (layout.div_quot(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.div_sign(j), - false, - vec![ - (layout.div_quot_signed(j), F::ONE), - (layout.div_quot_carry(j), F::from_u64(pow2_u64(32))), - (layout.div_quot(j), F::ONE), - (one, -F::from_u64(pow2_u64(32))), - ], - )); - constraints.push(Constraint::terms( - rs1_sign, - true, - vec![(layout.div_rem_signed(j), F::ONE), (layout.div_rem(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - rs1_sign, - false, - vec![ - (layout.div_rem_signed(j), F::ONE), - (layout.div_rem_carry(j), F::from_u64(pow2_u64(32))), - (layout.div_rem(j), F::ONE), - (one, -F::from_u64(pow2_u64(32))), - ], - )); - - // Writeback for DIVU/REMU. - constraints.push(Constraint::terms( - layout.is_divu(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.div_quot(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.is_remu(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.div_rem(j), -F::ONE)], - )); - - // Writeback for DIV (signed): divisor != 0 uses signed quotient, divisor == 0 yields -1. - constraints.push(Constraint::terms( - layout.div_nonzero(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.div_quot_signed(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.div_by_zero(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (one, -F::from_u64(u32::MAX as u64))], - )); - - // Writeback for REM (signed): signed remainder (dividend sign). - constraints.push(Constraint::terms( - layout.is_rem(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.div_rem_signed(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.rem_by_zero(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.rs1_val(j), -F::ONE)], - )); - - // For divisor != 0, require remainder < divisor via a SLTU Shout lookup. 
- constraints.push(Constraint::terms( - layout.div_rem_check(j), - false, - vec![(layout.alu_out(j), F::ONE), (one, -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.div_rem_check_signed(j), - false, - vec![(layout.alu_out(j), F::ONE), (one, -F::ONE)], - )); - - // Register update pattern for r=1..31: - // - if rd_sel[r]=1 then reg_out[r] = rd_write_val - // - else reg_out[r] = reg_in[r] - for r in 1..32 { - constraints.push(Constraint::terms( - layout.rd_sel(r, j), - false, - vec![(layout.reg_out(r, j), F::ONE), (layout.rd_write_val(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.rd_sel(r, j), - true, - vec![(layout.reg_out(r, j), F::ONE), (layout.reg_in(r, j), -F::ONE)], - )); - } - - // RAM effective address is computed via the ADD Shout lookup (mod 2^32 semantics). - constraints.push(Constraint::terms_or( - &[ - layout.is_lb(j), - layout.is_lbu(j), - layout.is_lh(j), - layout.is_lhu(j), - layout.is_lw(j), - ], - false, - vec![(layout.eff_addr(j), F::ONE), (layout.alu_out(j), -F::ONE)], - )); - constraints.push(Constraint::terms_or( - &[layout.is_sb(j), layout.is_sh(j), layout.is_sw(j)], - false, - vec![(layout.eff_addr(j), F::ONE), (layout.alu_out(j), -F::ONE)], - )); - - // Atomics use rs1 as the effective address (no immediate). - constraints.push(Constraint::terms_or( - &[ - layout.is_amoswap_w(j), - layout.is_amoadd_w(j), - layout.is_amoxor_w(j), - layout.is_amoor_w(j), - layout.is_amoand_w(j), - ], - false, - vec![(layout.eff_addr(j), F::ONE), (layout.rs1_val(j), -F::ONE)], - )); - - // RAM bus selectors must be derived from instruction flags to avoid bypassing Twist. 
- constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.ram_has_read(j), F::ONE), - (layout.is_lb(j), -F::ONE), - (layout.is_lbu(j), -F::ONE), - (layout.is_lh(j), -F::ONE), - (layout.is_lhu(j), -F::ONE), - (layout.is_lw(j), -F::ONE), - (layout.is_sb(j), -F::ONE), - (layout.is_sh(j), -F::ONE), - (layout.is_amoswap_w(j), -F::ONE), - (layout.is_amoadd_w(j), -F::ONE), - (layout.is_amoxor_w(j), -F::ONE), - (layout.is_amoor_w(j), -F::ONE), - (layout.is_amoand_w(j), -F::ONE), - ], - )); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.ram_has_write(j), F::ONE), - (layout.is_sb(j), -F::ONE), - (layout.is_sh(j), -F::ONE), - (layout.is_sw(j), -F::ONE), - (layout.is_amoswap_w(j), -F::ONE), - (layout.is_amoadd_w(j), -F::ONE), - (layout.is_amoxor_w(j), -F::ONE), - (layout.is_amoor_w(j), -F::ONE), - (layout.is_amoand_w(j), -F::ONE), - ], - )); - - // RAM write value (WV): SW and AMOSWAP use rs2, other AMOs use a Shout output. - constraints.push(Constraint::terms_or( - &[layout.is_sw(j), layout.is_amoswap_w(j)], - false, - vec![(layout.ram_wv(j), F::ONE), (layout.rs2_val(j), -F::ONE)], - )); - // SB/SH write: merge low byte/halfword into the existing word. 
- { - let mut terms = vec![(layout.ram_wv(j), F::ONE), (layout.mem_rv(j), -F::ONE)]; - for bit in 0..8 { - let coeff = F::from_u64(pow2_u64(bit)); - terms.push((layout.mem_rv_bit(bit, j), coeff)); - terms.push((layout.rs2_bit(bit, j), -coeff)); - } - constraints.push(Constraint::terms(layout.is_sb(j), false, terms)); - } - { - let mut terms = vec![(layout.ram_wv(j), F::ONE), (layout.mem_rv(j), -F::ONE)]; - for bit in 0..16 { - let coeff = F::from_u64(pow2_u64(bit)); - terms.push((layout.mem_rv_bit(bit, j), coeff)); - terms.push((layout.rs2_bit(bit, j), -coeff)); - } - constraints.push(Constraint::terms(layout.is_sh(j), false, terms)); - } - for &f in &[ - layout.is_amoadd_w(j), - layout.is_amoxor_w(j), - layout.is_amoor_w(j), - layout.is_amoand_w(j), - ] { - constraints.push(Constraint::terms( - f, - false, - vec![(layout.ram_wv(j), F::ONE), (layout.alu_out(j), -F::ONE)], - )); - } - - // Shout selectors. - // ADD table: add_has_lookup = is_add + is_addi + loads/stores + is_amoadd_w + is_auipc + is_jalr. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.add_has_lookup(j), F::ONE), - (layout.is_add(j), -F::ONE), - (layout.is_addi(j), -F::ONE), - (layout.is_lb(j), -F::ONE), - (layout.is_lbu(j), -F::ONE), - (layout.is_lh(j), -F::ONE), - (layout.is_lhu(j), -F::ONE), - (layout.is_lw(j), -F::ONE), - (layout.is_sb(j), -F::ONE), - (layout.is_sh(j), -F::ONE), - (layout.is_sw(j), -F::ONE), - (layout.is_amoadd_w(j), -F::ONE), - (layout.is_auipc(j), -F::ONE), - (layout.is_jalr(j), -F::ONE), - ], - )); - - // AND/XOR/OR tables (R-type + I-type + AMO ops). 
- constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.and_has_lookup(j), F::ONE), - (layout.is_and(j), -F::ONE), - (layout.is_andi(j), -F::ONE), - (layout.is_amoand_w(j), -F::ONE), - ], - )); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.xor_has_lookup(j), F::ONE), - (layout.is_xor(j), -F::ONE), - (layout.is_xori(j), -F::ONE), - (layout.is_amoxor_w(j), -F::ONE), - ], - )); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.or_has_lookup(j), F::ONE), - (layout.is_or(j), -F::ONE), - (layout.is_ori(j), -F::ONE), - (layout.is_amoor_w(j), -F::ONE), - ], - )); - - // Shift tables. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.sll_has_lookup(j), F::ONE), - (layout.is_sll(j), -F::ONE), - (layout.is_slli(j), -F::ONE), - ], - )); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.srl_has_lookup(j), F::ONE), - (layout.is_srl(j), -F::ONE), - (layout.is_srli(j), -F::ONE), - ], - )); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.sra_has_lookup(j), F::ONE), - (layout.is_sra(j), -F::ONE), - (layout.is_srai(j), -F::ONE), - ], - )); - - // SLT/SLTU tables (ALU + branch comparisons). - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.slt_has_lookup(j), F::ONE), - (layout.is_slt(j), -F::ONE), - (layout.is_slti(j), -F::ONE), - (layout.is_blt(j), -F::ONE), - (layout.is_bge(j), -F::ONE), - ], - )); - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.sltu_has_lookup(j), F::ONE), - (layout.is_sltu(j), -F::ONE), - (layout.is_sltiu(j), -F::ONE), - (layout.is_bltu(j), -F::ONE), - (layout.is_bgeu(j), -F::ONE), - (layout.div_rem_check(j), -F::ONE), - (layout.div_rem_check_signed(j), -F::ONE), - ], - )); - - // Twist RAM model (RV32 B1 / MVP): - // - RAM is byte-addressed (the Twist bus address is the architectural byte address). 
- // - Each Twist read/write is a full 32-bit word value at that byte address: the little-endian - // 4-byte window starting at `eff_addr`. - // - LB/LBU/LH/LHU derive rd_write_val from the low byte/halfword of `mem_rv`. - // - SB/SH are proven as read-modify-write over that same word window: `ram_wv` equals `mem_rv` - // with the low byte/halfword replaced by rs2's low bits. - // - // Alignment is enforced later via the low address bits on the RAM Twist lane. - - // Instruction-specific writeback: - // - RV32I ALU ops + AUIPC: rd_write_val = alu_out (verified via Shout) - // - Loads: rd_write_val derived from mem_rv (verified via Twist) - // - LUI: rd_write_val = imm_u (pure) - for &f in &[ - layout.is_add(j), - layout.is_sub(j), - layout.is_sll(j), - layout.is_slt(j), - layout.is_sltu(j), - layout.is_xor(j), - layout.is_srl(j), - layout.is_sra(j), - layout.is_or(j), - layout.is_and(j), - layout.is_addi(j), - layout.is_slti(j), - layout.is_sltiu(j), - layout.is_xori(j), - layout.is_ori(j), - layout.is_andi(j), - layout.is_slli(j), - layout.is_srli(j), - layout.is_srai(j), - layout.is_auipc(j), - ] { - constraints.push(Constraint::terms( - f, - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.alu_out(j), -F::ONE)], - )); - } - constraints.push(Constraint::terms_or( - &[ - layout.is_lw(j), - layout.is_amoswap_w(j), - layout.is_amoadd_w(j), - layout.is_amoxor_w(j), - layout.is_amoor_w(j), - layout.is_amoand_w(j), - ], - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.mem_rv(j), -F::ONE)], - )); - // LB: sign-extend low byte. - { - let mut terms = vec![(layout.rd_write_val(j), F::ONE)]; - for bit in 0..8 { - let coeff = if bit == 7 { - F::from_u64(pow2_u64(32) - pow2_u64(7)) - } else { - F::from_u64(pow2_u64(bit)) - }; - terms.push((layout.mem_rv_bit(bit, j), -coeff)); - } - constraints.push(Constraint::terms(layout.is_lb(j), false, terms)); - } - // LBU: zero-extend low byte. 
- { - let mut terms = vec![(layout.rd_write_val(j), F::ONE)]; - for bit in 0..8 { - terms.push((layout.mem_rv_bit(bit, j), -F::from_u64(pow2_u64(bit)))); - } - constraints.push(Constraint::terms(layout.is_lbu(j), false, terms)); - } - // LH: sign-extend low halfword. - { - let mut terms = vec![(layout.rd_write_val(j), F::ONE)]; - for bit in 0..16 { - let coeff = if bit == 15 { - F::from_u64(pow2_u64(32) - pow2_u64(15)) - } else { - F::from_u64(pow2_u64(bit)) - }; - terms.push((layout.mem_rv_bit(bit, j), -coeff)); - } - constraints.push(Constraint::terms(layout.is_lh(j), false, terms)); - } - // LHU: zero-extend low halfword. - { - let mut terms = vec![(layout.rd_write_val(j), F::ONE)]; - for bit in 0..16 { - terms.push((layout.mem_rv_bit(bit, j), -F::from_u64(pow2_u64(bit)))); - } - constraints.push(Constraint::terms(layout.is_lhu(j), false, terms)); - } - constraints.push(Constraint::terms( - layout.is_lui(j), - false, - vec![(layout.rd_write_val(j), F::ONE), (layout.imm_u(j), -F::ONE)], - )); - - // JAL/JALR writeback: rd_write_val = pc_in + 4. - constraints.push(Constraint::terms_or( - &[layout.is_jal(j), layout.is_jalr(j)], - false, - vec![ - (layout.rd_write_val(j), F::ONE), - (pc_in, -F::ONE), - (one, -F::from_u64(4)), - ], - )); - - // PC update: - // - Straight-line instructions: pc_out = pc_in + 4. 
- for &f in &[ - layout.is_add(j), - layout.is_sub(j), - layout.is_sll(j), - layout.is_slt(j), - layout.is_sltu(j), - layout.is_xor(j), - layout.is_srl(j), - layout.is_sra(j), - layout.is_or(j), - layout.is_and(j), - layout.is_mul(j), - layout.is_mulh(j), - layout.is_mulhu(j), - layout.is_mulhsu(j), - layout.is_div(j), - layout.is_divu(j), - layout.is_rem(j), - layout.is_remu(j), - layout.is_addi(j), - layout.is_slti(j), - layout.is_sltiu(j), - layout.is_xori(j), - layout.is_ori(j), - layout.is_andi(j), - layout.is_slli(j), - layout.is_srli(j), - layout.is_srai(j), - layout.is_lb(j), - layout.is_lbu(j), - layout.is_lh(j), - layout.is_lhu(j), - layout.is_lw(j), - layout.is_sb(j), - layout.is_sh(j), - layout.is_sw(j), - layout.is_amoswap_w(j), - layout.is_amoadd_w(j), - layout.is_amoxor_w(j), - layout.is_amoor_w(j), - layout.is_amoand_w(j), - layout.is_lui(j), - layout.is_auipc(j), - layout.is_fence(j), - layout.is_halt(j), - ] { - constraints.push(Constraint::terms( - f, - false, - vec![(pc_out, F::ONE), (pc_in, -F::ONE), (one, -F::from_u64(4))], - )); - } - - // - JAL: pc_out = pc_in + imm_j. - constraints.push(Constraint::terms( - layout.is_jal(j), - false, - vec![(pc_out, F::ONE), (pc_in, -F::ONE), (layout.imm_j(j), -F::ONE)], - )); - - // - JALR: pc_out = (rs1 + imm_i) & !1. - // - // The ADD-table Shout output `alu_out` is (rs1 + imm_i) mod 2^32. - // Let a0,b0 be the operand LSBs (from the interleaved ADD key bits), and let a0b0 = a0*b0. - // Then lsb(alu_out) = a0 XOR b0 = a0 + b0 - 2*a0b0, and pc_out = alu_out - lsb. - constraints.push(Constraint::terms( - layout.is_jalr(j), - false, - vec![ - (pc_out, F::ONE), - (layout.alu_out(j), -F::ONE), - (add_a0, F::ONE), - (add_b0, F::ONE), - (layout.add_a0b0(j), -F::from_u64(2)), - ], - )); - - let branch_flags = [ - layout.is_beq(j), - layout.is_bne(j), - layout.is_blt(j), - layout.is_bge(j), - layout.is_bltu(j), - layout.is_bgeu(j), - ]; - - // Branch control: br_taken/br_not_taken are only set on branch rows. 
- constraints.push(Constraint::terms_or( - &branch_flags, - true, // (1 - is_branch) - vec![(layout.br_taken(j), F::ONE)], - )); - constraints.push(Constraint::terms_or( - &branch_flags, - true, // (1 - is_branch) - vec![(layout.br_not_taken(j), F::ONE)], - )); - - // Exactly one branch case on branch rows: br_taken + br_not_taken = is_branch. - constraints.push(Constraint::terms( - one, - false, - vec![ - (layout.br_taken(j), F::ONE), - (layout.br_not_taken(j), F::ONE), - (layout.is_beq(j), -F::ONE), - (layout.is_bne(j), -F::ONE), - (layout.is_blt(j), -F::ONE), - (layout.is_bge(j), -F::ONE), - (layout.is_bltu(j), -F::ONE), - (layout.is_bgeu(j), -F::ONE), - ], - )); - - // Branch decision comes from the Shout output: - // - BEQ/BNE/BLT/BLTU: br_taken = alu_out - // - BGE/BGEU: br_taken = 1 - alu_out - constraints.push(Constraint::terms_or( - &[layout.is_beq(j), layout.is_bne(j), layout.is_blt(j), layout.is_bltu(j)], - false, - vec![(layout.br_taken(j), F::ONE), (layout.alu_out(j), -F::ONE)], - )); - constraints.push(Constraint::terms_or( - &[layout.is_bge(j), layout.is_bgeu(j)], - false, - vec![ - (layout.br_taken(j), F::ONE), - (layout.alu_out(j), F::ONE), - (one, -F::ONE), - ], - )); - - // Branch PC update: - // - Taken: pc_out = pc_in + imm_b. - // - Not taken: pc_out = pc_in + 4. - constraints.push(Constraint::terms( - layout.br_taken(j), - false, - vec![(pc_out, F::ONE), (pc_in, -F::ONE), (layout.imm_b(j), -F::ONE)], - )); - constraints.push(Constraint::terms( - layout.br_not_taken(j), - false, - vec![(pc_out, F::ONE), (pc_in, -F::ONE), (one, -F::from_u64(4))], - )); - - // Helper: bind the product of the ADD-table operand LSBs (used for JALR mask in Phase 2). 
- constraints.push(Constraint::mul(add_a0, add_b0, layout.add_a0b0(j))); - - // --- Shout key correctness (ADD table bus addr bits interleaving) --- - let mut even_terms = vec![(layout.rs1_val(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = add_cols.addr_bits.start + 2 * i; - let bit = layout.bus.bus_cell(bit_col_id, j); - even_terms.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms_or( - &[ - layout.is_add(j), - layout.is_addi(j), - layout.is_lb(j), - layout.is_lbu(j), - layout.is_lh(j), - layout.is_lhu(j), - layout.is_lw(j), - layout.is_sb(j), - layout.is_sh(j), - layout.is_sw(j), - layout.is_jalr(j), - ], - false, - even_terms, - )); - - // AMOADD.W uses mem_rv as the left ADD operand (old memory value). - let mut even_terms_amoadd = vec![(layout.mem_rv(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = add_cols.addr_bits.start + 2 * i; - let bit = layout.bus.bus_cell(bit_col_id, j); - even_terms_amoadd.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(layout.is_amoadd_w(j), false, even_terms_amoadd)); - - let mut even_terms_auipc = vec![(pc_in, F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = add_cols.addr_bits.start + 2 * i; - let bit = layout.bus.bus_cell(bit_col_id, j); - even_terms_auipc.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(layout.is_auipc(j), false, even_terms_auipc)); - - let mut odd_terms_add = vec![(layout.rs2_val(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = add_cols.addr_bits.start + 2 * i + 1; - let bit = layout.bus.bus_cell(bit_col_id, j); - odd_terms_add.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms_or( - &[layout.is_add(j), layout.is_amoadd_w(j)], - false, - odd_terms_add, - )); - - let mut odd_terms_addi = vec![(layout.imm_i(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = add_cols.addr_bits.start + 2 * i + 1; - let bit = layout.bus.bus_cell(bit_col_id, j); - 
odd_terms_addi.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms_or( - &[ - layout.is_addi(j), - layout.is_lb(j), - layout.is_lbu(j), - layout.is_lh(j), - layout.is_lhu(j), - layout.is_lw(j), - layout.is_jalr(j), - ], - false, - odd_terms_addi, - )); - - let mut odd_terms_sw = vec![(layout.imm_s(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = add_cols.addr_bits.start + 2 * i + 1; - let bit = layout.bus.bus_cell(bit_col_id, j); - odd_terms_sw.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms_or( - &[layout.is_sb(j), layout.is_sh(j), layout.is_sw(j)], - false, - odd_terms_sw, - )); - - let mut odd_terms_auipc = vec![(layout.imm_u(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = add_cols.addr_bits.start + 2 * i + 1; - let bit = layout.bus.bus_cell(bit_col_id, j); - odd_terms_auipc.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(layout.is_auipc(j), false, odd_terms_auipc)); - - // --- Shout key correctness (EQ/NEQ table bus addr bits interleaving) --- - if let Some(eq_cols) = eq_cols { - let flag = layout.is_beq(j); - let mut even = vec![(layout.rs1_val(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = eq_cols.addr_bits.start + 2 * i; - let bit = layout.bus.bus_cell(bit_col_id, j); - even.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(flag, false, even)); - - let mut odd = vec![(layout.rs2_val(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = eq_cols.addr_bits.start + 2 * i + 1; - let bit = layout.bus.bus_cell(bit_col_id, j); - odd.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(flag, false, odd)); - } - if let Some(neq_cols) = neq_cols { - let flag = layout.is_bne(j); - let mut even = vec![(layout.rs1_val(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = neq_cols.addr_bits.start + 2 * i; - let bit = layout.bus.bus_cell(bit_col_id, j); - even.push((bit, 
-F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(flag, false, even)); - - let mut odd = vec![(layout.rs2_val(j), F::ONE)]; - for i in 0..RV32_XLEN { - let bit_col_id = neq_cols.addr_bits.start + 2 * i + 1; - let bit = layout.bus.bus_cell(bit_col_id, j); - odd.push((bit, -F::from_u64(pow2_u64(i)))); - } - constraints.push(Constraint::terms(flag, false, odd)); - } - - // --- Shout key correctness (other opcode tables) --- - // AND / OR / XOR (R-type uses rs2, I-type uses imm_i). - if let Some(and_cols) = and_cols { - constraints.push(Constraint::terms_or( - &[layout.is_and(j), layout.is_andi(j)], - false, - pack_interleaved_operand(and_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_amoand_w(j), - false, - pack_interleaved_operand(and_cols.addr_bits.start, j, 0, layout.mem_rv(j)), - )); - constraints.push(Constraint::terms( - layout.is_and(j), - false, - pack_interleaved_operand(and_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_amoand_w(j), - false, - pack_interleaved_operand(and_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_andi(j), - false, - pack_interleaved_operand(and_cols.addr_bits.start, j, 1, layout.imm_i(j)), - )); - } - - if let Some(or_cols) = or_cols { - constraints.push(Constraint::terms_or( - &[layout.is_or(j), layout.is_ori(j)], - false, - pack_interleaved_operand(or_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_amoor_w(j), - false, - pack_interleaved_operand(or_cols.addr_bits.start, j, 0, layout.mem_rv(j)), - )); - constraints.push(Constraint::terms( - layout.is_or(j), - false, - pack_interleaved_operand(or_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_amoor_w(j), - false, - pack_interleaved_operand(or_cols.addr_bits.start, j, 1, 
layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_ori(j), - false, - pack_interleaved_operand(or_cols.addr_bits.start, j, 1, layout.imm_i(j)), - )); - } - - if let Some(xor_cols) = xor_cols { - constraints.push(Constraint::terms_or( - &[layout.is_xor(j), layout.is_xori(j)], - false, - pack_interleaved_operand(xor_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_amoxor_w(j), - false, - pack_interleaved_operand(xor_cols.addr_bits.start, j, 0, layout.mem_rv(j)), - )); - constraints.push(Constraint::terms( - layout.is_xor(j), - false, - pack_interleaved_operand(xor_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_amoxor_w(j), - false, - pack_interleaved_operand(xor_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_xori(j), - false, - pack_interleaved_operand(xor_cols.addr_bits.start, j, 1, layout.imm_i(j)), - )); - } - - // SUB (R-type only). - if let Some(sub_cols) = sub_cols { - constraints.push(Constraint::terms( - layout.is_sub(j), - false, - pack_interleaved_operand(sub_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_sub(j), - false, - pack_interleaved_operand(sub_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - } - - // Shifts (R-type uses rs2, I-type uses shamt). 
- if let Some(sll_cols) = sll_cols { - constraints.push(Constraint::terms_or( - &[layout.is_sll(j), layout.is_slli(j)], - false, - pack_interleaved_operand(sll_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_sll(j), - false, - pack_interleaved_operand(sll_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_slli(j), - false, - pack_interleaved_operand(sll_cols.addr_bits.start, j, 1, layout.shamt(j)), - )); - } - - if let Some(srl_cols) = srl_cols { - constraints.push(Constraint::terms_or( - &[layout.is_srl(j), layout.is_srli(j)], - false, - pack_interleaved_operand(srl_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_srl(j), - false, - pack_interleaved_operand(srl_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_srli(j), - false, - pack_interleaved_operand(srl_cols.addr_bits.start, j, 1, layout.shamt(j)), - )); - } - - if let Some(sra_cols) = sra_cols { - constraints.push(Constraint::terms_or( - &[layout.is_sra(j), layout.is_srai(j)], - false, - pack_interleaved_operand(sra_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_sra(j), - false, - pack_interleaved_operand(sra_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_srai(j), - false, - pack_interleaved_operand(sra_cols.addr_bits.start, j, 1, layout.shamt(j)), - )); - } - - // SLT/SLTU (ALU + branch comparisons). 
- if let Some(slt_cols) = slt_cols { - constraints.push(Constraint::terms_or( - &[layout.is_slt(j), layout.is_slti(j), layout.is_blt(j), layout.is_bge(j)], - false, - pack_interleaved_operand(slt_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - constraints.push(Constraint::terms_or( - &[layout.is_slt(j), layout.is_blt(j), layout.is_bge(j)], - false, - pack_interleaved_operand(slt_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.is_slti(j), - false, - pack_interleaved_operand(slt_cols.addr_bits.start, j, 1, layout.imm_i(j)), - )); - } - - if let Some(sltu_cols) = sltu_cols { - constraints.push(Constraint::terms_or( - &[ - layout.is_sltu(j), - layout.is_sltiu(j), - layout.is_bltu(j), - layout.is_bgeu(j), - ], - false, - pack_interleaved_operand(sltu_cols.addr_bits.start, j, 0, layout.rs1_val(j)), - )); - // DIVU/REMU remainder validity check uses SLTU(rem, divisor). - constraints.push(Constraint::terms( - layout.div_rem_check(j), - false, - pack_interleaved_operand(sltu_cols.addr_bits.start, j, 0, layout.div_rem(j)), - )); - // DIV/REM remainder validity check uses SLTU(|rem|, |divisor|). 
- constraints.push(Constraint::terms( - layout.div_rem_check_signed(j), - false, - pack_interleaved_operand(sltu_cols.addr_bits.start, j, 0, layout.div_rem(j)), - )); - constraints.push(Constraint::terms_or( - &[layout.is_sltu(j), layout.is_bltu(j), layout.is_bgeu(j)], - false, - pack_interleaved_operand(sltu_cols.addr_bits.start, j, 1, layout.rs2_val(j)), - )); - constraints.push(Constraint::terms( - layout.div_rem_check(j), - false, - pack_interleaved_operand(sltu_cols.addr_bits.start, j, 1, layout.div_divisor(j)), - )); - constraints.push(Constraint::terms( - layout.div_rem_check_signed(j), - false, - pack_interleaved_operand(sltu_cols.addr_bits.start, j, 1, layout.div_divisor(j)), - )); - constraints.push(Constraint::terms( - layout.is_sltiu(j), - false, - pack_interleaved_operand(sltu_cols.addr_bits.start, j, 1, layout.imm_i(j)), - )); - } - - // --- Alignment constraints (MVP) --- - // ROM fetch is always 32-bit, so enforce pc_in % 4 == 0 via PROG read address bits. - let prog_bit0 = layout.bus.bus_cell(prog.ra_bits.start + 0, j); - let prog_bit1 = layout.bus.bus_cell(prog.ra_bits.start + 1, j); - constraints.push(Constraint::zero(one, prog_bit0)); - constraints.push(Constraint::zero(one, prog_bit1)); - - // Enforce alignment for half/word accesses via RAM bus addr bits. 
- let ra0 = layout.bus.bus_cell(ram.ra_bits.start + 0, j); - let ra1 = layout.bus.bus_cell(ram.ra_bits.start + 1, j); - let wa0 = layout.bus.bus_cell(ram.wa_bits.start + 0, j); - let wa1 = layout.bus.bus_cell(ram.wa_bits.start + 1, j); - let amo_flags = [ - layout.is_amoswap_w(j), - layout.is_amoadd_w(j), - layout.is_amoxor_w(j), - layout.is_amoor_w(j), - layout.is_amoand_w(j), - ]; - constraints.push(Constraint::terms_or( - &[ - layout.is_lh(j), - layout.is_lhu(j), - layout.is_lw(j), - layout.is_sh(j), - amo_flags[0], - amo_flags[1], - amo_flags[2], - amo_flags[3], - amo_flags[4], - ], - false, - vec![(ra0, F::ONE)], - )); - constraints.push(Constraint::terms_or( - &[ - layout.is_lw(j), - amo_flags[0], - amo_flags[1], - amo_flags[2], - amo_flags[3], - amo_flags[4], - ], - false, - vec![(ra1, F::ONE)], - )); - constraints.push(Constraint::terms_or( - &[ - layout.is_sh(j), - layout.is_sw(j), - amo_flags[0], - amo_flags[1], - amo_flags[2], - amo_flags[3], - amo_flags[4], - ], - false, - vec![(wa0, F::ONE)], - )); - constraints.push(Constraint::terms_or( - &[ - layout.is_sw(j), - amo_flags[0], - amo_flags[1], - amo_flags[2], - amo_flags[3], - amo_flags[4], - ], - false, - vec![(wa1, F::ONE)], - )); - } - - // --- Intra-chunk composition / padding semantics --- - // Enforce monotone activity and state continuity: - // - is_active[j+1] => is_active[j] - // - pc_in[j+1] == pc_out[j] and regs_in[j+1] == regs_out[j] for all j - // - // The unconditional continuity ensures padding rows (is_active=0) *carry* the final - // architectural state forward, making the final state unambiguous in an L1-style layout. - for j in 0..layout.chunk_size.saturating_sub(1) { - let a = layout.is_active(j); - let b = layout.is_active(j + 1); - - // b * (1 - a) = 0 - constraints.push(Constraint::terms(b, false, vec![(one, F::ONE), (a, -F::ONE)])); - - // HALT terminates execution within a chunk: halt_effective[j] => is_active[j+1] == 0. 
- constraints.push(Constraint::terms( - layout.halt_effective(j), - false, - vec![(layout.is_active(j + 1), F::ONE)], - )); - - // pc_in[j+1] - pc_out[j] = 0 - constraints.push(Constraint::terms( - one, - false, - vec![(layout.pc_in(j + 1), F::ONE), (layout.pc_out(j), -F::ONE)], - )); - - // regs_in[j+1] - regs_out[j] = 0 for all regs - for r in 0..32 { - constraints.push(Constraint::terms( - one, - false, - vec![(layout.reg_in(r, j + 1), F::ONE), (layout.reg_out(r, j), -F::ONE)], - )); - } - } - - Ok(constraints) -} - -/// Build the RV32 B1 step CCS and its witness layout. -/// -/// Requirements: -/// - `mem_layouts` must include `RAM_ID` and `PROG_ID`. -/// - `mem_layouts[PROG_ID]` is byte-addressed (`n_side=2`, `ell=1`). -/// -/// `shout_table_ids` must be non-empty and include the RV32 `ADD` table id (3). Any subset of the -/// base RV32I opcode tables (ids 0..=11) is allowed (unused tables can be bound to fixed-zero CPU -/// columns in the shared-bus config). -pub fn build_rv32_b1_step_ccs( - mem_layouts: &HashMap, - shout_table_ids: &[u32], - chunk_size: usize, -) -> Result<(CcsStructure, Rv32B1Layout), String> { - if chunk_size == 0 { - return Err("RV32 B1: chunk_size must be >= 1".into()); - } - let ram_id = RAM_ID.0; - let prog_id = PROG_ID.0; - if !mem_layouts.contains_key(&ram_id) { - return Err(format!("RV32 B1: mem_layouts missing RAM_ID={ram_id}")); - } - if !mem_layouts.contains_key(&prog_id) { - return Err(format!("RV32 B1: mem_layouts missing PROG_ID={prog_id}")); - } - - // B1 circuit currently assumes only RISC-V opcode Shout tables (ell_addr = 2*xlen = 64). 
- let (table_ids, shout_ell_addrs) = derive_shout_ids_and_ell_addrs(shout_table_ids)?; - - let (mem_ids, twist_ell_addrs) = derive_mem_ids_and_ell_addrs(mem_layouts)?; - if mem_ids.len() != twist_ell_addrs.len() { - return Err("RV32 B1: internal error (twist ell addrs mismatch)".into()); - } - let bus_cols_per_step: usize = shout_ell_addrs.iter().sum::() - + 2 * shout_ell_addrs.len() - + twist_ell_addrs - .iter() - .map(|&ell_addr| 2 * ell_addr + 5) - .sum::(); - let bus_region_len = bus_cols_per_step - .checked_mul(chunk_size) - .ok_or_else(|| "RV32 B1: bus_region_len overflow".to_string())?; - - // Probe layout to learn the CPU column footprint and count injected constraints. - // We rebuild once with the minimal `m` after computing exact sizes. - let mut probe_m = bus_region_len - .checked_add(1) - .ok_or_else(|| "RV32 B1: probe_m overflow".to_string())?; - let probe = loop { - match build_layout_with_m(probe_m, mem_layouts, &table_ids, chunk_size) { - Ok(layout) => break layout, - Err(e) - if e.contains("need more padding columns before bus tail") || e.contains("overlaps public inputs") => - { - probe_m = probe_m - .checked_mul(2) - .ok_or_else(|| "RV32 B1: probe_m overflow".to_string())?; - } - Err(e) => return Err(e), - } - }; - let cpu_cols_used = probe.halt_effective + chunk_size; - - let injected = injected_bus_constraints_len(&probe, &table_ids, &mem_ids); - - let m_cols_min = cpu_cols_used + bus_region_len; - - let mut m = m_cols_min; - let layout = loop { - match build_layout_with_m(m, mem_layouts, &table_ids, chunk_size) { - Ok(layout) => break layout, - Err(e) - if e.contains("need more padding columns before bus tail") || e.contains("overlaps public inputs") => - { - m = m - .checked_mul(2) - .ok_or_else(|| "RV32 B1: m overflow".to_string())?; - } - Err(e) => return Err(e), - } - }; - let constraints = semantic_constraints(&layout, mem_layouts)?; - let n = constraints - .len() - .checked_add(injected) - .ok_or_else(|| "RV32 B1: n 
overflow".to_string())?; - let ccs = build_r1cs_ccs(&constraints, n, layout.m, layout.const_one)?; - Ok((ccs, layout)) -} diff --git a/crates/neo-memory/src/riscv/ccs/bus_bindings.rs b/crates/neo-memory/src/riscv/ccs/bus_bindings.rs index 3bd0a8f0..28f60381 100644 --- a/crates/neo-memory/src/riscv/ccs/bus_bindings.rs +++ b/crates/neo-memory/src/riscv/ccs/bus_bindings.rs @@ -1,184 +1,436 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use p3_goldilocks::Goldilocks as F; -use crate::cpu::constraints::{CpuConstraintBuilder, ShoutCpuBinding, TwistCpuBinding}; +use crate::cpu::bus_layout::{ + build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes, BusLayout, ShoutInstanceShape, +}; +use crate::cpu::constraints::{ + CpuConstraint, CpuConstraintBuilder, ShoutCpuBinding, TwistCpuBinding, CPU_BUS_COL_DISABLED, +}; use crate::cpu::r1cs_adapter::SharedCpuBusConfig; use crate::plain::PlainMemLayout; -use crate::riscv::lookups::{PROG_ID, RAM_ID}; +use crate::riscv::lookups::{PROG_ID, RAM_ID, REG_ID}; +use crate::riscv::trace::{ + rv32_decode_lookup_table_id_for_col, rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, + rv32_trace_lookup_addr_group_for_table_id, rv32_trace_lookup_selector_group_for_table_id, Rv32DecodeSidecarLayout, +}; -use super::config::{derive_mem_ids_and_ell_addrs, derive_shout_ids_and_ell_addrs}; use super::constants::{ - ADD_TABLE_ID, AND_TABLE_ID, EQ_TABLE_ID, NEQ_TABLE_ID, OR_TABLE_ID, SLL_TABLE_ID, SLTU_TABLE_ID, SLT_TABLE_ID, - SRA_TABLE_ID, SRL_TABLE_ID, SUB_TABLE_ID, XOR_TABLE_ID, + ADD_TABLE_ID, AND_TABLE_ID, DIVU_TABLE_ID, DIV_TABLE_ID, EQ_TABLE_ID, MULHSU_TABLE_ID, MULHU_TABLE_ID, + MULH_TABLE_ID, MUL_TABLE_ID, NEQ_TABLE_ID, OR_TABLE_ID, REMU_TABLE_ID, REM_TABLE_ID, RV32_XLEN, SLL_TABLE_ID, + SLTU_TABLE_ID, SLT_TABLE_ID, SRA_TABLE_ID, SRL_TABLE_ID, SUB_TABLE_ID, XOR_TABLE_ID, }; -use super::Rv32B1Layout; +use super::trace::Rv32TraceCcsLayout; + +/// Additional trace-mode Shout lookup family 
specification. +/// +/// This allows callers to provision lookup families beyond the fixed RV32 opcode +/// tables, with table-specific address widths (`ell_addr`) and value count. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TraceShoutBusSpec { + pub table_id: u32, + pub ell_addr: usize, + pub n_vals: usize, +} + +#[inline] +fn trace_cpu_col(layout: &Rv32TraceCcsLayout, trace_col: usize) -> usize { + layout.cell(trace_col, 0) +} + +#[inline] +fn trace_shout_binding(layout: &Rv32TraceCcsLayout, table_id: u32) -> Option { + if rv32_is_decode_lookup_table_id(table_id) { + // Decode lookup families are keyed by PROG read address (pc_before). + Some(ShoutCpuBinding { + has_lookup: CPU_BUS_COL_DISABLED, + addr: Some(trace_cpu_col(layout, layout.trace.pc_before)), + val: CPU_BUS_COL_DISABLED, + }) + } else if rv32_is_width_lookup_table_id(table_id) { + // Width helper lookup families are keyed by cycle index. + Some(ShoutCpuBinding { + has_lookup: CPU_BUS_COL_DISABLED, + addr: Some(trace_cpu_col(layout, layout.trace.cycle)), + val: CPU_BUS_COL_DISABLED, + }) + } else { + None + } +} -fn shout_cpu_binding(layout: &Rv32B1Layout, table_id: u32) -> ShoutCpuBinding { +#[inline] +fn validate_trace_shout_table_id(table_id: u32) -> Result<(), String> { match table_id { - AND_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.and_has_lookup, - addr: None, - val: layout.alu_out, - }, - XOR_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.xor_has_lookup, - addr: None, - val: layout.alu_out, - }, - OR_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.or_has_lookup, - addr: None, - val: layout.alu_out, - }, - ADD_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.add_has_lookup, - addr: None, - val: layout.alu_out, - }, - SUB_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.is_sub, - addr: None, - val: layout.alu_out, - }, - SLT_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.slt_has_lookup, - addr: None, - val: layout.alu_out, - }, - SLTU_TABLE_ID => 
ShoutCpuBinding { - has_lookup: layout.sltu_has_lookup, - addr: None, - val: layout.alu_out, - }, - SLL_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.sll_has_lookup, - addr: None, - val: layout.alu_out, - }, - SRL_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.srl_has_lookup, - addr: None, - val: layout.alu_out, - }, - SRA_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.sra_has_lookup, - addr: None, - val: layout.alu_out, - }, - EQ_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.is_beq, - addr: None, - val: layout.alu_out, - }, - NEQ_TABLE_ID => ShoutCpuBinding { - has_lookup: layout.is_bne, - addr: None, - val: layout.alu_out, - }, - _ => { - // Bind unused tables to fixed-zero CPU columns so they are provably inactive. - let zero = layout.reg_in(0, 0); - ShoutCpuBinding { - has_lookup: zero, - addr: None, - val: zero, - } - } - } -} - -fn twist_cpu_binding(layout: &Rv32B1Layout, mem_id: u32) -> TwistCpuBinding { + AND_TABLE_ID | XOR_TABLE_ID | OR_TABLE_ID | ADD_TABLE_ID | SUB_TABLE_ID | SLT_TABLE_ID | SLTU_TABLE_ID + | SLL_TABLE_ID | SRL_TABLE_ID | SRA_TABLE_ID | EQ_TABLE_ID | NEQ_TABLE_ID | MUL_TABLE_ID | MULH_TABLE_ID + | MULHU_TABLE_ID | MULHSU_TABLE_ID | DIV_TABLE_ID | DIVU_TABLE_ID | REM_TABLE_ID | REMU_TABLE_ID => Ok(()), + _ => Err(format!("RV32 trace shared bus: unsupported shout table_id={table_id}")), + } +} + +#[inline] +fn trace_lookup_addr_group_for_table_id(table_id: u32) -> Option { + rv32_trace_lookup_addr_group_for_table_id(table_id) +} + +#[inline] +fn trace_lookup_selector_group_for_table_id(table_id: u32) -> Option { + rv32_trace_lookup_selector_group_for_table_id(table_id) +} + +#[derive(Clone, Copy, Debug)] +struct TraceShoutShape { + table_id: u32, + ell_addr: usize, + n_vals: usize, + addr_group: Option, + selector_group: Option, +} + +fn derive_trace_shout_shapes( + shout_table_ids: &[u32], + extra_shout_specs: &[TraceShoutBusSpec], +) -> Result, String> { + let mut shape_by_table_id = HashMap::::new(); + + for &table_id in 
shout_table_ids { + validate_trace_shout_table_id(table_id)?; + shape_by_table_id.insert( + table_id, + TraceShoutShape { + table_id, + ell_addr: 2 * RV32_XLEN, + n_vals: 1usize, + addr_group: trace_lookup_addr_group_for_table_id(table_id), + selector_group: trace_lookup_selector_group_for_table_id(table_id), + }, + ); + } + + for spec in extra_shout_specs { + if spec.ell_addr == 0 { + return Err(format!( + "RV32 trace shared bus: extra shout spec for table_id={} has ell_addr=0", + spec.table_id + )); + } + if spec.n_vals == 0 { + return Err(format!( + "RV32 trace shared bus: extra shout spec for table_id={} has n_vals=0", + spec.table_id + )); + } + if let Some(prev) = shape_by_table_id.get(&spec.table_id) { + if prev.ell_addr != spec.ell_addr { + return Err(format!( + "RV32 trace shared bus: conflicting ell_addr for table_id={} (base/spec mismatch: {} vs {})", + spec.table_id, prev.ell_addr, spec.ell_addr + )); + } + if prev.n_vals != spec.n_vals { + return Err(format!( + "RV32 trace shared bus: conflicting n_vals for table_id={} (base/spec mismatch: {} vs {})", + spec.table_id, prev.n_vals, spec.n_vals + )); + } + let inferred_group = trace_lookup_addr_group_for_table_id(spec.table_id); + if prev.addr_group != inferred_group { + return Err(format!( + "RV32 trace shared bus: conflicting addr_group for table_id={} (base/spec mismatch: {:?} vs {:?})", + spec.table_id, prev.addr_group, inferred_group + )); + } + let inferred_selector_group = trace_lookup_selector_group_for_table_id(spec.table_id); + if prev.selector_group != inferred_selector_group { + return Err(format!( + "RV32 trace shared bus: conflicting selector_group for table_id={} (base/spec mismatch: {:?} vs {:?})", + spec.table_id, prev.selector_group, inferred_selector_group + )); + } + } else { + shape_by_table_id.insert( + spec.table_id, + TraceShoutShape { + table_id: spec.table_id, + ell_addr: spec.ell_addr, + n_vals: spec.n_vals, + addr_group: trace_lookup_addr_group_for_table_id(spec.table_id), + 
selector_group: trace_lookup_selector_group_for_table_id(spec.table_id), + }, + ); + } + } + + let mut shapes: Vec = shape_by_table_id.into_values().collect(); + shapes.sort_unstable_by_key(|shape| shape.table_id); + Ok(shapes) +} + +fn audit_bus_tail_constraint_coverage( + builder: &CpuConstraintBuilder, + bus: &crate::cpu::bus_layout::BusLayout, +) -> Result<(), String> { + let mut referenced = vec![false; bus.bus_cols]; + let bus_end = bus + .bus_base + .checked_add(bus.bus_region_len()) + .ok_or_else(|| "RV32 trace shared bus: bus tail end overflow during coverage audit".to_string())?; + + let mut mark_col = |col: usize| { + if col >= bus.bus_base && col < bus_end { + let rel = col - bus.bus_base; + let col_id = rel / bus.chunk_size; + if col_id < referenced.len() { + referenced[col_id] = true; + } + } + }; + + for c in builder.constraints() { + mark_col(c.condition_col); + for &col in &c.additional_condition_cols { + mark_col(col); + } + for &(col, _) in &c.b_terms { + mark_col(col); + } + } + + let dead: Vec = referenced + .iter() + .enumerate() + .filter_map(|(i, used)| if *used { None } else { Some(i) }) + .collect(); + + if dead.is_empty() { + return Ok(()); + } + + let preview: Vec = dead.iter().copied().take(24).collect(); + Err(format!( + "RV32 trace shared bus: dead bus-tail columns are not referenced by constraints (count={}, first={preview:?})", + dead.len() + )) +} + +#[inline] +fn trace_disabled_twist_binding(_layout: &Rv32TraceCcsLayout) -> TwistCpuBinding { + TwistCpuBinding { + has_read: CPU_BUS_COL_DISABLED, + has_write: CPU_BUS_COL_DISABLED, + read_addr: CPU_BUS_COL_DISABLED, + write_addr: CPU_BUS_COL_DISABLED, + rv: CPU_BUS_COL_DISABLED, + wv: CPU_BUS_COL_DISABLED, + inc: None, + } +} + +#[derive(Clone, Copy, Debug)] +struct TraceDecodeSelectorCols { + rd_has_write: usize, + ram_has_read: usize, + ram_has_write: usize, +} + +fn resolve_trace_decode_selector_cols( + layout: &Rv32TraceCcsLayout, + shout_shapes: &[TraceShoutShape], + 
mem_layouts: &HashMap, +) -> Result { + let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); + mem_ids.sort_unstable(); + let mut twist_shapes = Vec::with_capacity(mem_ids.len()); + for mem_id in &mem_ids { + let mem_layout = mem_layouts + .get(mem_id) + .ok_or_else(|| format!("RV32 trace shared bus: missing mem layout for mem_id={mem_id}"))?; + if mem_layout.n_side == 0 || !mem_layout.n_side.is_power_of_two() { + return Err(format!( + "RV32 trace shared bus: mem_id={mem_id} n_side={} must be power-of-two", + mem_layout.n_side + )); + } + let ell = mem_layout.n_side.trailing_zeros() as usize; + let ell_addr = mem_layout.d * ell; + twist_shapes.push((ell_addr, mem_layout.lanes.max(1))); + } + + let bus = build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes( + layout.m, + layout.m_in, + layout.t, + shout_shapes.iter().map(|shape| ShoutInstanceShape { + ell_addr: shape.ell_addr, + lanes: 1usize, + n_vals: shape.n_vals.max(1), + addr_group: shape.addr_group.map(|v| v as u64), + selector_group: shape.selector_group.map(|v| v as u64), + }), + twist_shapes.iter().copied(), + )?; + + trace_decode_selector_cols_from_bus(&bus, shout_shapes) +} + +fn trace_decode_selector_cols_from_bus( + bus: &crate::cpu::bus_layout::BusLayout, + shout_shapes: &[TraceShoutShape], +) -> Result { + let decode_layout = Rv32DecodeSidecarLayout::new(); + let rd_has_write_table_id = rv32_decode_lookup_table_id_for_col(decode_layout.rd_has_write); + let ram_has_read_table_id = rv32_decode_lookup_table_id_for_col(decode_layout.ram_has_read); + let ram_has_write_table_id = rv32_decode_lookup_table_id_for_col(decode_layout.ram_has_write); + let table_val_col = |table_id: u32| -> Result { + let shout_idx = shout_shapes + .iter() + .position(|shape| shape.table_id == table_id) + .ok_or_else(|| { + format!( + "RV32 trace shared bus: missing decode lookup table_id={table_id} required for Twist selector binding" + ) + })?; + let inst_cols = bus.shout_cols.get(shout_idx).ok_or_else(|| { 
+ format!("RV32 trace shared bus: missing shout cols for decode lookup table_id={table_id}") + })?; + let lane0 = inst_cols.lanes.first().ok_or_else(|| { + format!("RV32 trace shared bus: expected one shout lane for decode lookup table_id={table_id}") + })?; + bus.bus_base + .checked_add(lane0.primary_val() * bus.chunk_size) + .ok_or_else(|| "RV32 trace shared bus: decode selector column overflow".to_string()) + }; + Ok(TraceDecodeSelectorCols { + rd_has_write: table_val_col(rd_has_write_table_id)?, + ram_has_read: table_val_col(ram_has_read_table_id)?, + ram_has_write: table_val_col(ram_has_write_table_id)?, + }) +} + +#[inline] +fn trace_twist_primary_binding( + layout: &Rv32TraceCcsLayout, + mem_id: u32, + decode_selectors: TraceDecodeSelectorCols, +) -> TwistCpuBinding { + let active = trace_cpu_col(layout, layout.trace.active); if mem_id == RAM_ID.0 { TwistCpuBinding { - has_read: layout.ram_has_read, - has_write: layout.ram_has_write, - read_addr: layout.eff_addr, - write_addr: layout.eff_addr, - rv: layout.mem_rv, - wv: layout.ram_wv, + has_read: decode_selectors.ram_has_read, + has_write: decode_selectors.ram_has_write, + read_addr: trace_cpu_col(layout, layout.trace.ram_addr), + write_addr: trace_cpu_col(layout, layout.trace.ram_addr), + rv: trace_cpu_col(layout, layout.trace.ram_rv), + wv: trace_cpu_col(layout, layout.trace.ram_wv), inc: None, } } else if mem_id == PROG_ID.0 { - let zero = layout.reg_in(0, 0); TwistCpuBinding { - has_read: layout.is_active, - has_write: zero, - read_addr: layout.pc_in, - write_addr: zero, - rv: layout.instr_word, - wv: zero, + has_read: active, + has_write: CPU_BUS_COL_DISABLED, + read_addr: trace_cpu_col(layout, layout.trace.pc_before), + write_addr: CPU_BUS_COL_DISABLED, + rv: trace_cpu_col(layout, layout.trace.instr_word), + wv: CPU_BUS_COL_DISABLED, inc: None, } - } else { - // Disable any additional Twist instances by binding to fixed-zero CPU columns. 
- let zero = layout.reg_in(0, 0); + } else if mem_id == REG_ID.0 { TwistCpuBinding { - has_read: zero, - has_write: zero, - read_addr: zero, - write_addr: zero, - rv: zero, - wv: zero, + has_read: active, + has_write: decode_selectors.rd_has_write, + read_addr: trace_cpu_col(layout, layout.trace.rs1_addr), + write_addr: trace_cpu_col(layout, layout.trace.rd_addr), + rv: trace_cpu_col(layout, layout.trace.rs1_val), + wv: trace_cpu_col(layout, layout.trace.rd_val), inc: None, } + } else { + trace_disabled_twist_binding(layout) } } -pub(super) fn injected_bus_constraints_len(layout: &Rv32B1Layout, table_ids: &[u32], mem_ids: &[u32]) -> usize { - let shout_cpu: Vec = table_ids - .iter() - .map(|&id| shout_cpu_binding(layout, id)) - .collect(); - let twist_cpu: Vec = mem_ids - .iter() - .map(|&id| twist_cpu_binding(layout, id)) - .collect(); - - let mut builder = CpuConstraintBuilder::::new(layout.m, layout.m, layout.const_one); - for (i, cpu) in shout_cpu.iter().enumerate() { - builder.add_shout_instance_bound(&layout.bus, &layout.bus.shout_cols[i].lanes[0], cpu); - } - for (i, cpu) in twist_cpu.iter().enumerate() { - builder.add_twist_instance_bound(&layout.bus, &layout.bus.twist_cols[i].lanes[0], cpu); - } - builder.constraints().len() +/// Shared CPU-bus bindings for the RV32 trace-wiring step circuit. +pub fn rv32_trace_shared_cpu_bus_config( + layout: &Rv32TraceCcsLayout, + shout_table_ids: &[u32], + mem_layouts: HashMap, + initial_mem: HashMap<(u32, u64), F>, +) -> Result, String> { + rv32_trace_shared_cpu_bus_config_with_specs(layout, shout_table_ids, &[], mem_layouts, initial_mem) } -/// Shared CPU-bus bindings for the RV32 B1 step circuit. 
-/// -/// This config: -/// - binds `PROG_ID` reads to `pc_in` / `instr_word`, forces no ROM writes, -/// - binds `RAM_ID` reads/writes to `eff_addr` / `mem_rv` / `ram_wv` (with selectors derived from instruction flags), -/// - binds RV32IM Shout opcode tables (ids 0..=19) to `alu_out` (addr_bits are constrained directly by the step CCS). -pub fn rv32_b1_shared_cpu_bus_config( - layout: &Rv32B1Layout, +/// Shared CPU-bus bindings for trace mode with extra lookup-family Shout specs. +pub fn rv32_trace_shared_cpu_bus_config_with_specs( + layout: &Rv32TraceCcsLayout, shout_table_ids: &[u32], + extra_shout_specs: &[TraceShoutBusSpec], mem_layouts: HashMap, initial_mem: HashMap<(u32, u64), F>, ) -> Result, String> { - let (table_ids, _ell_addrs) = derive_shout_ids_and_ell_addrs(shout_table_ids)?; + let shout_shapes = derive_trace_shout_shapes(shout_table_ids, extra_shout_specs)?; + let decode_selectors = resolve_trace_decode_selector_cols(layout, &shout_shapes, &mem_layouts)?; let mut shout_cpu = HashMap::new(); - for table_id in table_ids { - shout_cpu.insert(table_id, vec![shout_cpu_binding(layout, table_id)]); + for shape in &shout_shapes { + // Opcode lookup families remain reduction-owned; decode/width families bind key columns. 
+ let binding = trace_shout_binding(layout, shape.table_id); + shout_cpu.insert(shape.table_id, binding.into_iter().collect()); } - let (mem_ids, _ell_addrs) = derive_mem_ids_and_ell_addrs(&mem_layouts)?; + let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); + mem_ids.sort_unstable(); let mut twist_cpu = HashMap::new(); for mem_id in mem_ids { let lanes = mem_layouts .get(&mem_id) .map(|l| l.lanes.max(1)) - .unwrap_or(1); - let primary = twist_cpu_binding(layout, mem_id); - let disabled = twist_cpu_binding(layout, u32::MAX); - let mut bindings = Vec::with_capacity(lanes); - bindings.push(primary); - for _ in 1..lanes { - bindings.push(disabled.clone()); + .ok_or_else(|| format!("RV32 trace shared bus: missing mem layout for mem_id={mem_id}"))?; + + if mem_id == REG_ID.0 { + if lanes < 2 { + return Err(format!( + "RV32 trace shared bus: REG_ID requires lanes>=2 (got lanes={lanes})" + )); + } + let mut bindings = Vec::with_capacity(lanes); + bindings.push(trace_twist_primary_binding(layout, mem_id, decode_selectors)); + bindings.push(TwistCpuBinding { + has_read: trace_cpu_col(layout, layout.trace.active), + has_write: CPU_BUS_COL_DISABLED, + read_addr: trace_cpu_col(layout, layout.trace.rs2_addr), + write_addr: CPU_BUS_COL_DISABLED, + rv: trace_cpu_col(layout, layout.trace.rs2_val), + wv: CPU_BUS_COL_DISABLED, + inc: None, + }); + let disabled = trace_disabled_twist_binding(layout); + for _ in 2..lanes { + bindings.push(disabled.clone()); + } + twist_cpu.insert(mem_id, bindings); + } else { + let primary = trace_twist_primary_binding(layout, mem_id, decode_selectors); + let disabled = trace_disabled_twist_binding(layout); + let mut bindings = Vec::with_capacity(lanes); + bindings.push(primary); + for _ in 1..lanes { + bindings.push(disabled.clone()); + } + twist_cpu.insert(mem_id, bindings); + } + } + + let mut shout_addr_groups = HashMap::new(); + let mut shout_selector_groups = HashMap::new(); + for shape in &shout_shapes { + if let Some(g) = 
shape.addr_group { + shout_addr_groups.insert(shape.table_id, g as u64); + } + if let Some(g) = shape.selector_group { + shout_selector_groups.insert(shape.table_id, g as u64); } - twist_cpu.insert(mem_id, bindings); } Ok(SharedCpuBusConfig { @@ -187,5 +439,253 @@ pub fn rv32_b1_shared_cpu_bus_config( const_one_col: layout.const_one, shout_cpu, twist_cpu, + shout_addr_groups, + shout_selector_groups, + }) +} + +/// Return `(bus_region_len, reserved_rows)` required by trace shared-bus mode. +pub fn rv32_trace_shared_bus_requirements( + layout: &Rv32TraceCcsLayout, + shout_table_ids: &[u32], + mem_layouts: &HashMap, +) -> Result<(usize, usize), String> { + rv32_trace_shared_bus_requirements_with_specs(layout, shout_table_ids, &[], mem_layouts) +} + +/// Return `(bus_region_len, reserved_rows)` required by trace shared-bus mode with extra lookup-family specs. +pub fn rv32_trace_shared_bus_requirements_with_specs( + layout: &Rv32TraceCcsLayout, + shout_table_ids: &[u32], + extra_shout_specs: &[TraceShoutBusSpec], + mem_layouts: &HashMap, +) -> Result<(usize, usize), String> { + let snapshot = + rv32_trace_shared_bus_extraction_with_specs(layout, shout_table_ids, extra_shout_specs, mem_layouts)?; + Ok((snapshot.bus.bus_region_len(), snapshot.constraints.len())) +} + +/// Debug/extractor snapshot for trace shared-bus planning. +/// +/// This exposes the exact bus layout and injected CPU-bus constraints that would be +/// synthesized for the given configuration. It is intended for tooling (e.g. Lean +/// extraction), not for proving/verification hot paths. +#[derive(Clone, Debug)] +pub struct TraceSharedBusExtraction { + pub bus: BusLayout, + pub constraints: Vec>, +} + +/// Return the shared-bus layout and constraint list required by trace mode. 
+pub fn rv32_trace_shared_bus_extraction( + layout: &Rv32TraceCcsLayout, + shout_table_ids: &[u32], + mem_layouts: &HashMap, +) -> Result { + rv32_trace_shared_bus_extraction_with_specs(layout, shout_table_ids, &[], mem_layouts) +} + +/// Return the shared-bus layout and constraint list required by trace mode with extra lookup-family specs. +pub fn rv32_trace_shared_bus_extraction_with_specs( + layout: &Rv32TraceCcsLayout, + shout_table_ids: &[u32], + extra_shout_specs: &[TraceShoutBusSpec], + mem_layouts: &HashMap, +) -> Result { + let shout_shapes = derive_trace_shout_shapes(shout_table_ids, extra_shout_specs)?; + + let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); + mem_ids.sort_unstable(); + + let mut shout_cols = 0usize; + let mut seen_addr_groups = HashMap::::new(); + let mut seen_selector_groups = HashSet::::new(); + for shape in &shout_shapes { + if let Some(group) = shape.addr_group { + if let Some(prev_ell) = seen_addr_groups.insert(group, shape.ell_addr) { + if prev_ell != shape.ell_addr { + return Err(format!( + "RV32 trace shared bus: addr_group={} has conflicting ell_addr ({} vs {})", + group, prev_ell, shape.ell_addr + )); + } + } else { + shout_cols = shout_cols + .checked_add(shape.ell_addr) + .ok_or_else(|| "RV32 trace shared bus: shout shared-addr width overflow".to_string())?; + } + } else { + shout_cols = shout_cols + .checked_add(shape.ell_addr) + .ok_or_else(|| "RV32 trace shared bus: shout lane width overflow".to_string())?; + } + if let Some(selector_group) = shape.selector_group { + if seen_selector_groups.insert(selector_group) { + shout_cols = shout_cols + .checked_add(1) + .ok_or_else(|| "RV32 trace shared bus: shout selector width overflow".to_string())?; + } + } else { + shout_cols = shout_cols + .checked_add(1) + .ok_or_else(|| "RV32 trace shared bus: shout selector width overflow".to_string())?; + } + shout_cols = shout_cols + .checked_add(shape.n_vals) + .ok_or_else(|| "RV32 trace shared bus: shout value width 
overflow".to_string())?; + } + + let mut twist_cols = 0usize; + let mut twist_shapes = Vec::with_capacity(mem_ids.len()); + for mem_id in &mem_ids { + let mem_layout = mem_layouts + .get(mem_id) + .ok_or_else(|| format!("RV32 trace shared bus: missing mem layout for mem_id={mem_id}"))?; + if mem_layout.n_side == 0 || !mem_layout.n_side.is_power_of_two() { + return Err(format!( + "RV32 trace shared bus: mem_id={mem_id} n_side={} must be power-of-two", + mem_layout.n_side + )); + } + let ell = mem_layout.n_side.trailing_zeros() as usize; + let ell_addr = mem_layout.d * ell; + let lanes = mem_layout.lanes.max(1); + if *mem_id == REG_ID.0 && lanes < 2 { + return Err(format!( + "RV32 trace shared bus: REG_ID requires lanes>=2 (got lanes={lanes})" + )); + } + twist_cols = twist_cols + .checked_add((2 * ell_addr + 5) * lanes) + .ok_or_else(|| "RV32 trace shared bus: twist bus column overflow".to_string())?; + twist_shapes.push((ell_addr, lanes)); + } + + let bus_cols = shout_cols + .checked_add(twist_cols) + .ok_or_else(|| "RV32 trace shared bus: bus column overflow".to_string())?; + let bus_region_len = bus_cols + .checked_mul(layout.t) + .ok_or_else(|| "RV32 trace shared bus: bus region overflow".to_string())?; + let m_total = layout + .m + .checked_add(bus_region_len) + .ok_or_else(|| "RV32 trace shared bus: total m overflow".to_string())?; + + let bus = build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes( + m_total, + layout.m_in, + layout.t, + shout_shapes.iter().map(|shape| ShoutInstanceShape { + ell_addr: shape.ell_addr, + lanes: 1usize, + n_vals: shape.n_vals.max(1), + addr_group: shape.addr_group.map(|v| v as u64), + selector_group: shape.selector_group.map(|v| v as u64), + }), + twist_shapes.iter().copied(), + )?; + let decode_selectors = trace_decode_selector_cols_from_bus(&bus, &shout_shapes)?; + + let mut builder = CpuConstraintBuilder::::new(m_total, m_total, layout.const_one); + + let mut addr_range_counts = HashMap::<(usize, usize), 
usize>::new(); + for inst_cols in &bus.shout_cols { + for lane_cols in &inst_cols.lanes { + let key = (lane_cols.addr_bits.start, lane_cols.addr_bits.end); + *addr_range_counts.entry(key).or_insert(0) += 1; + } + } + + let mut addr_range_bitness_added = HashSet::<(usize, usize)>::new(); + let mut selector_bitness_added = HashSet::::new(); + let mut shout_key_binding_added = HashSet::<(bool, usize, usize, usize, usize)>::new(); + for (i, shape) in shout_shapes.iter().enumerate() { + let lane0 = &bus.shout_cols[i].lanes[0]; + if let Some(binding) = trace_shout_binding(layout, shape.table_id) { + let mut dedup_binding = binding.clone(); + if let Some(addr_base) = dedup_binding.addr { + let (is_bus_gate, gate_base) = if dedup_binding.has_lookup == CPU_BUS_COL_DISABLED { + (true, lane0.has_lookup) + } else { + (false, dedup_binding.has_lookup) + }; + let key_sig = ( + is_bus_gate, + gate_base, + addr_base, + lane0.addr_bits.start, + lane0.addr_bits.end, + ); + if !shout_key_binding_added.insert(key_sig) { + dedup_binding.addr = None; + } + } + builder.add_shout_instance_linkage_bound(&bus, lane0, &dedup_binding); + } + + let key = (lane0.addr_bits.start, lane0.addr_bits.end); + let shared_addr_group = addr_range_counts.get(&key).copied().unwrap_or(0) > 1; + let selector_first = selector_bitness_added.insert(lane0.has_lookup); + if shared_addr_group { + if selector_first { + builder.add_shout_instance_padding_value_only(&bus, lane0); + } else { + builder.add_shout_instance_value_padding_only(&bus, lane0); + } + if addr_range_bitness_added.insert(key) { + builder.add_shout_instance_addr_bit_bitness(&bus, lane0); + } + } else if selector_first { + builder.add_shout_instance_padding(&bus, lane0); + } else { + builder.add_shout_instance_padding_without_selector_bitness(&bus, lane0); + } + } + + for (i, &mem_id) in mem_ids.iter().enumerate() { + let inst = &bus.twist_cols[i]; + if inst.lanes.is_empty() { + continue; + } + + if mem_id == REG_ID.0 { + let lane0 = 
trace_twist_primary_binding(layout, mem_id, decode_selectors); + builder.add_twist_instance_bound(&bus, &inst.lanes[0], &lane0); + let lane1 = TwistCpuBinding { + has_read: trace_cpu_col(layout, layout.trace.active), + has_write: CPU_BUS_COL_DISABLED, + read_addr: trace_cpu_col(layout, layout.trace.rs2_addr), + write_addr: CPU_BUS_COL_DISABLED, + rv: trace_cpu_col(layout, layout.trace.rs2_val), + wv: CPU_BUS_COL_DISABLED, + inc: None, + }; + if inst.lanes.len() >= 2 { + builder.add_twist_instance_bound(&bus, &inst.lanes[1], &lane1); + } + if inst.lanes.len() > 2 { + let disabled = trace_disabled_twist_binding(layout); + for lane_cols in &inst.lanes[2..] { + builder.add_twist_instance_bound(&bus, lane_cols, &disabled); + } + } + } else { + let lane0 = trace_twist_primary_binding(layout, mem_id, decode_selectors); + builder.add_twist_instance_bound(&bus, &inst.lanes[0], &lane0); + if inst.lanes.len() > 1 { + let disabled = trace_disabled_twist_binding(layout); + for lane_cols in &inst.lanes[1..] 
{ + builder.add_twist_instance_bound(&bus, lane_cols, &disabled); + } + } + } + } + + audit_bus_tail_constraint_coverage(&builder, &bus)?; + + Ok(TraceSharedBusExtraction { + bus, + constraints: builder.constraints().to_vec(), }) } diff --git a/crates/neo-memory/src/riscv/ccs/config.rs b/crates/neo-memory/src/riscv/ccs/config.rs deleted file mode 100644 index 3f3e2809..00000000 --- a/crates/neo-memory/src/riscv/ccs/config.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::collections::HashMap; - -use crate::plain::PlainMemLayout; - -use super::constants::{ADD_TABLE_ID, REMU_TABLE_ID, RV32_XLEN}; - -pub(super) fn derive_mem_ids_and_ell_addrs( - mem_layouts: &HashMap, -) -> Result<(Vec, Vec), String> { - let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); - mem_ids.sort_unstable(); - - let mut twist_ell_addrs = Vec::with_capacity(mem_ids.len()); - for mem_id in &mem_ids { - let layout = mem_layouts - .get(mem_id) - .ok_or_else(|| format!("missing mem_layout for mem_id={mem_id}"))?; - if layout.n_side == 0 || !layout.n_side.is_power_of_two() { - return Err(format!( - "mem_id={mem_id}: n_side={} must be power of two", - layout.n_side - )); - } - let ell = layout.n_side.trailing_zeros() as usize; - twist_ell_addrs.push(layout.d * ell); - } - - Ok((mem_ids, twist_ell_addrs)) -} - -pub(super) fn derive_shout_ids_and_ell_addrs(shout_table_ids: &[u32]) -> Result<(Vec, Vec), String> { - let mut table_ids: Vec = shout_table_ids.to_vec(); - table_ids.sort_unstable(); - table_ids.dedup(); - if table_ids.is_empty() { - return Err("RV32 B1: shout_table_ids must be non-empty".into()); - } - if !table_ids.contains(&ADD_TABLE_ID) { - return Err(format!( - "RV32 B1: shout_table_ids must include ADD table_id={ADD_TABLE_ID}" - )); - } - // This circuit supports RV32IM via the 20 base+M opcode tables (ids 0..=19). Callers may pass any - // subset, as long as it covers the opcodes that will actually appear in the VM trace. 
- // (Missing table specs are rejected by `build_shard_witness_shared_cpu_bus` when the trace contains - // a Shout event for an unlisted `table_id`.) - for &table_id in &table_ids { - if table_id > REMU_TABLE_ID { - return Err(format!( - "RV32 B1: unsupported table_id={table_id} (expected RISC-V opcode table ids 0..={REMU_TABLE_ID})" - )); - } - } - // MVP: every Shout table in this circuit is a RISC-V opcode table with d=2*xlen, n_side=2 => ell_addr=2*xlen. - let shout_ell_addrs = vec![2 * RV32_XLEN; table_ids.len()]; - Ok((table_ids, shout_ell_addrs)) -} diff --git a/crates/neo-memory/src/riscv/ccs/constraint_builder.rs b/crates/neo-memory/src/riscv/ccs/constraint_builder.rs index 0ba8ca4f..6ab7f7cc 100644 --- a/crates/neo-memory/src/riscv/ccs/constraint_builder.rs +++ b/crates/neo-memory/src/riscv/ccs/constraint_builder.rs @@ -14,26 +14,6 @@ pub(super) struct Constraint { } impl Constraint { - pub fn eq_const(condition_col: usize, const_one_col: usize, left: usize, c: u64) -> Self { - Self { - condition_col, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(left, Ff::ONE), (const_one_col, -Ff::from_u64(c))], - c_terms: Vec::new(), - } - } - - pub fn zero(condition_col: usize, col: usize) -> Self { - Self { - condition_col, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(col, Ff::ONE)], - c_terms: Vec::new(), - } - } - pub fn terms(condition_col: usize, negate_condition: bool, b_terms: Vec<(usize, Ff)>) -> Self { Self { condition_col, @@ -43,27 +23,6 @@ impl Constraint { c_terms: Vec::new(), } } - - pub fn terms_or(condition_cols: &[usize], negate_condition: bool, b_terms: Vec<(usize, Ff)>) -> Self { - assert!(!condition_cols.is_empty(), "need at least one condition column"); - Self { - condition_col: condition_cols[0], - negate_condition, - additional_condition_cols: condition_cols[1..].to_vec(), - b_terms, - c_terms: Vec::new(), - } - } - - pub fn mul(left: usize, right: usize, out: usize) 
-> Self { - Self { - condition_col: left, - negate_condition: false, - additional_condition_cols: Vec::new(), - b_terms: vec![(right, Ff::ONE)], - c_terms: vec![(out, Ff::ONE)], - } - } } pub(super) fn build_r1cs_ccs( @@ -73,17 +32,19 @@ pub(super) fn build_r1cs_ccs( const_one_col: usize, ) -> Result, String> { if m == 0 { - return Err("RV32 B1 CCS: m must be >= 1".into()); + return Err("RV32 trace CCS: m must be >= 1".into()); } if n == 0 { - return Err("RV32 B1 CCS: n must be >= 1".into()); + return Err("RV32 trace CCS: n must be >= 1".into()); } if const_one_col >= m { - return Err(format!("RV32 B1 CCS: const_one_col({const_one_col}) must be < m({m})")); + return Err(format!( + "RV32 trace CCS: const_one_col({const_one_col}) must be < m({m})" + )); } if constraints.len() > n { return Err(format!( - "RV32 B1 CCS: too many constraints ({}) for CCS with n={} m={}", + "RV32 trace CCS: too many constraints ({}) for CCS with n={} m={}", constraints.len(), n, m @@ -137,12 +98,7 @@ pub(super) fn build_r1cs_ccs( ], ); - // Match `neo_ccs::r1cs_to_ccs`: insert identity-first only when square. 
- let (matrices, f) = if n == m { - (vec![CcsMatrix::Identity { n }, a, b, c], f_base.insert_var_at_front()) - } else { - (vec![a, b, c], f_base) - }; + let matrices = vec![a, b, c]; - CcsStructure::new_sparse(matrices, f).map_err(|e| format!("RV32 B1 CCS: invalid structure: {e:?}")) + CcsStructure::new_sparse(matrices, f_base).map_err(|e| format!("RV32 trace CCS: invalid structure: {e:?}")) } diff --git a/crates/neo-memory/src/riscv/ccs/layout.rs b/crates/neo-memory/src/riscv/ccs/layout.rs deleted file mode 100644 index 3d4fca2b..00000000 --- a/crates/neo-memory/src/riscv/ccs/layout.rs +++ /dev/null @@ -1,1326 +0,0 @@ -use std::collections::HashMap; - -use crate::cpu::bus_layout::{build_bus_layout_for_instances, BusLayout}; -use crate::plain::PlainMemLayout; -use crate::riscv::lookups::{PROG_ID, RAM_ID}; - -use super::config::{derive_mem_ids_and_ell_addrs, derive_shout_ids_and_ell_addrs}; - -/// Witness/column layout for the RV32 B1 step circuit. -#[derive(Clone, Debug)] -pub struct Rv32B1Layout { - pub m_in: usize, - pub m: usize, - pub chunk_size: usize, - pub const_one: usize, - // Public I/O (single values per chunk). - pub pc0: usize, - pub regs0_start: usize, // 32 cols - pub pc_final: usize, - pub regs_final_start: usize, // 32 cols - pub halted_in: usize, - pub halted_out: usize, - pub is_active: usize, - - pub pc_in: usize, - pub pc_out: usize, - pub instr_word: usize, - - pub regs_in_start: usize, - pub regs_out_start: usize, - - pub instr_bits_start: usize, // 32 bits - - pub opcode: usize, - pub funct3: usize, - pub funct7: usize, - pub rd_field: usize, - pub rs1_field: usize, - pub rs2_field: usize, - - pub imm12_raw: usize, - pub imm_i: usize, - pub imm_s: usize, - pub imm_u: usize, - pub imm_b_raw: usize, - pub imm_b: usize, - pub imm_j_raw: usize, - pub imm_j: usize, - - // One-hot instruction flags (sum == is_active). 
- pub is_add: usize, - pub is_sub: usize, - pub is_sll: usize, - pub is_slt: usize, - pub is_sltu: usize, - pub is_xor: usize, - pub is_srl: usize, - pub is_sra: usize, - pub is_or: usize, - pub is_and: usize, - - // RV32M (R-type, funct7=0b0000001). - pub is_mul: usize, - pub is_mulh: usize, - pub is_mulhu: usize, - pub is_mulhsu: usize, - pub is_div: usize, - pub is_divu: usize, - pub is_rem: usize, - pub is_remu: usize, - - pub is_addi: usize, - pub is_slti: usize, - pub is_sltiu: usize, - pub is_xori: usize, - pub is_ori: usize, - pub is_andi: usize, - pub is_slli: usize, - pub is_srli: usize, - pub is_srai: usize, - - pub is_lb: usize, - pub is_lbu: usize, - pub is_lh: usize, - pub is_lhu: usize, - pub is_lw: usize, - pub is_sb: usize, - pub is_sh: usize, - pub is_sw: usize, - // RV32A (atomics, word only). - pub is_amoswap_w: usize, - pub is_amoadd_w: usize, - pub is_amoxor_w: usize, - pub is_amoor_w: usize, - pub is_amoand_w: usize, - pub is_lui: usize, - pub is_auipc: usize, - pub is_beq: usize, - pub is_bne: usize, - pub is_blt: usize, - pub is_bge: usize, - pub is_bltu: usize, - pub is_bgeu: usize, - pub is_jal: usize, - pub is_jalr: usize, - pub is_fence: usize, - pub is_halt: usize, - - pub br_taken: usize, - pub br_not_taken: usize, - - pub rs1_sel_start: usize, // 32 - pub rs2_sel_start: usize, // 32 - pub rd_sel_start: usize, // 32 - - pub rs1_val: usize, - pub rs2_val: usize, - - pub alu_out: usize, - pub mem_rv: usize, - pub mem_rv_bits_start: usize, // 32 - pub eff_addr: usize, - // RAM bus selectors/values (must be tied to instruction flags to avoid bypassing Twist). 
- pub ram_has_read: usize, - pub ram_has_write: usize, - pub ram_wv: usize, - pub rd_write_val: usize, - pub rd_write_bits_start: usize, // 32 - - pub add_has_lookup: usize, - pub and_has_lookup: usize, - pub xor_has_lookup: usize, - pub or_has_lookup: usize, - pub sll_has_lookup: usize, - pub srl_has_lookup: usize, - pub sra_has_lookup: usize, - pub slt_has_lookup: usize, - pub sltu_has_lookup: usize, - pub lookup_key: usize, - pub add_a0b0: usize, - - // In-circuit RV32M helpers (avoid requiring implicit Shout tables). - // MUL* helpers: rs1_val * rs2_val = mul_lo + 2^32 * mul_hi - pub mul_lo: usize, - pub mul_hi: usize, - pub mul_lo_bits_start: usize, // 32 - pub mul_hi_bits_start: usize, // 32 - pub mul_hi_prefix_start: usize, // 31 - pub mul_carry: usize, - pub mul_carry_bits_start: usize, // 2 - - // Signed helpers: rs1/rs2 bits + absolute values. - pub rs1_bits_start: usize, // 32 - pub rs2_bits_start: usize, // 32 - pub rs2_zero_prefix_start: usize, // 31 - pub rs1_abs: usize, - pub rs2_abs: usize, - pub rs1_rs2_sign_and: usize, - pub rs1_sign_rs2_val: usize, - pub rs2_sign_rs1_val: usize, - - // DIV/REM helpers (unsigned + signed). - pub div_quot: usize, - pub div_rem: usize, - pub div_quot_signed: usize, - pub div_rem_signed: usize, - pub div_quot_carry: usize, - pub div_rem_carry: usize, - pub div_prod: usize, - pub div_divisor: usize, - pub div_quot_bits_start: usize, // 32 - pub div_rem_bits_start: usize, // 32 - pub rs2_is_zero: usize, - pub rs2_nonzero: usize, - pub is_divu_or_remu: usize, - pub divu_by_zero: usize, - pub is_div_or_rem: usize, - pub div_nonzero: usize, - pub rem_nonzero: usize, - pub div_by_zero: usize, - pub rem_by_zero: usize, - pub div_sign: usize, - pub div_rem_check: usize, - pub div_rem_check_signed: usize, - // ECALL helpers (Jolt marker/print IDs). 
- pub ecall_a0_bits_start: usize, // 32 - pub ecall_cycle_prefix_start: usize, // 31 - pub ecall_is_cycle: usize, - pub ecall_print_prefix_start: usize, // 31 - pub ecall_is_print: usize, - pub ecall_halts: usize, - pub halt_effective: usize, - - pub bus: BusLayout, - pub mem_ids: Vec, - pub table_ids: Vec, - pub ram_twist_idx: usize, - pub prog_twist_idx: usize, -} - -impl Rv32B1Layout { - #[inline] - fn cpu_cell(&self, base: usize, j: usize) -> usize { - debug_assert!(j < self.chunk_size, "cpu j out of range"); - base + j - } - - #[inline] - pub fn is_active(&self, j: usize) -> usize { - self.cpu_cell(self.is_active, j) - } - - #[inline] - pub fn pc_in(&self, j: usize) -> usize { - self.cpu_cell(self.pc_in, j) - } - - #[inline] - pub fn pc_out(&self, j: usize) -> usize { - self.cpu_cell(self.pc_out, j) - } - - #[inline] - pub fn instr_word(&self, j: usize) -> usize { - self.cpu_cell(self.instr_word, j) - } - - pub fn reg_in(&self, r: usize, j: usize) -> usize { - assert!(r < 32); - self.regs_in_start + r * self.chunk_size + j - } - - pub fn reg_out(&self, r: usize, j: usize) -> usize { - assert!(r < 32); - self.regs_out_start + r * self.chunk_size + j - } - - pub fn instr_bit(&self, i: usize, j: usize) -> usize { - assert!(i < 32); - self.instr_bits_start + i * self.chunk_size + j - } - - pub fn rs1_sel(&self, r: usize, j: usize) -> usize { - assert!(r < 32); - self.rs1_sel_start + r * self.chunk_size + j - } - - pub fn rs2_sel(&self, r: usize, j: usize) -> usize { - assert!(r < 32); - self.rs2_sel_start + r * self.chunk_size + j - } - - pub fn rd_sel(&self, r: usize, j: usize) -> usize { - assert!(r < 32); - self.rd_sel_start + r * self.chunk_size + j - } - - #[inline] - pub fn rs1_val(&self, j: usize) -> usize { - self.cpu_cell(self.rs1_val, j) - } - - #[inline] - pub fn rs2_val(&self, j: usize) -> usize { - self.cpu_cell(self.rs2_val, j) - } - - #[inline] - pub fn alu_out(&self, j: usize) -> usize { - self.cpu_cell(self.alu_out, j) - } - - #[inline] - pub fn 
mem_rv(&self, j: usize) -> usize { - self.cpu_cell(self.mem_rv, j) - } - - pub fn mem_rv_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.mem_rv_bits_start + bit * self.chunk_size + j - } - - #[inline] - pub fn eff_addr(&self, j: usize) -> usize { - self.cpu_cell(self.eff_addr, j) - } - - #[inline] - pub fn ram_has_read(&self, j: usize) -> usize { - self.cpu_cell(self.ram_has_read, j) - } - - #[inline] - pub fn ram_has_write(&self, j: usize) -> usize { - self.cpu_cell(self.ram_has_write, j) - } - - #[inline] - pub fn ram_wv(&self, j: usize) -> usize { - self.cpu_cell(self.ram_wv, j) - } - - #[inline] - pub fn rd_write_val(&self, j: usize) -> usize { - self.cpu_cell(self.rd_write_val, j) - } - - pub fn rd_write_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.rd_write_bits_start + bit * self.chunk_size + j - } - - #[inline] - pub fn lookup_key(&self, j: usize) -> usize { - self.cpu_cell(self.lookup_key, j) - } - - #[inline] - pub fn add_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.add_has_lookup, j) - } - - #[inline] - pub fn and_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.and_has_lookup, j) - } - - #[inline] - pub fn xor_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.xor_has_lookup, j) - } - - #[inline] - pub fn or_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.or_has_lookup, j) - } - - #[inline] - pub fn mul_lo(&self, j: usize) -> usize { - self.cpu_cell(self.mul_lo, j) - } - - pub fn mul_lo_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.mul_lo_bits_start + bit * self.chunk_size + j - } - - #[inline] - pub fn mul_hi(&self, j: usize) -> usize { - self.cpu_cell(self.mul_hi, j) - } - - pub fn mul_hi_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.mul_hi_bits_start + bit * self.chunk_size + j - } - - #[inline] - pub fn mul_hi_prefix(&self, k: usize, j: usize) -> usize { - assert!(k < 31); - self.mul_hi_prefix_start + k 
* self.chunk_size + j - } - - #[inline] - pub fn mul_carry(&self, j: usize) -> usize { - self.cpu_cell(self.mul_carry, j) - } - - pub fn mul_carry_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 2); - self.mul_carry_bits_start + bit * self.chunk_size + j - } - - #[inline] - pub fn div_quot(&self, j: usize) -> usize { - self.cpu_cell(self.div_quot, j) - } - - #[inline] - pub fn div_rem(&self, j: usize) -> usize { - self.cpu_cell(self.div_rem, j) - } - - #[inline] - pub fn div_quot_signed(&self, j: usize) -> usize { - self.cpu_cell(self.div_quot_signed, j) - } - - #[inline] - pub fn div_rem_signed(&self, j: usize) -> usize { - self.cpu_cell(self.div_rem_signed, j) - } - - #[inline] - pub fn div_quot_carry(&self, j: usize) -> usize { - self.cpu_cell(self.div_quot_carry, j) - } - - #[inline] - pub fn div_rem_carry(&self, j: usize) -> usize { - self.cpu_cell(self.div_rem_carry, j) - } - - #[inline] - pub fn div_prod(&self, j: usize) -> usize { - self.cpu_cell(self.div_prod, j) - } - - #[inline] - pub fn div_divisor(&self, j: usize) -> usize { - self.cpu_cell(self.div_divisor, j) - } - - pub fn div_quot_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.div_quot_bits_start + bit * self.chunk_size + j - } - - pub fn div_rem_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.div_rem_bits_start + bit * self.chunk_size + j - } - - pub fn rs1_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.rs1_bits_start + bit * self.chunk_size + j - } - - #[inline] - pub fn rs2_is_zero(&self, j: usize) -> usize { - self.cpu_cell(self.rs2_is_zero, j) - } - - pub fn rs2_bit(&self, bit: usize, j: usize) -> usize { - assert!(bit < 32); - self.rs2_bits_start + bit * self.chunk_size + j - } - - pub fn rs2_zero_prefix(&self, idx: usize, j: usize) -> usize { - assert!(idx < 31); - self.rs2_zero_prefix_start + idx * self.chunk_size + j - } - - #[inline] - pub fn rs2_nonzero(&self, j: usize) -> usize { - 
self.cpu_cell(self.rs2_nonzero, j) - } - - #[inline] - pub fn rs1_abs(&self, j: usize) -> usize { - self.cpu_cell(self.rs1_abs, j) - } - - #[inline] - pub fn rs2_abs(&self, j: usize) -> usize { - self.cpu_cell(self.rs2_abs, j) - } - - #[inline] - pub fn rs1_rs2_sign_and(&self, j: usize) -> usize { - self.cpu_cell(self.rs1_rs2_sign_and, j) - } - - #[inline] - pub fn rs1_sign_rs2_val(&self, j: usize) -> usize { - self.cpu_cell(self.rs1_sign_rs2_val, j) - } - - #[inline] - pub fn rs2_sign_rs1_val(&self, j: usize) -> usize { - self.cpu_cell(self.rs2_sign_rs1_val, j) - } - - #[inline] - pub fn is_divu_or_remu(&self, j: usize) -> usize { - self.cpu_cell(self.is_divu_or_remu, j) - } - - #[inline] - pub fn divu_by_zero(&self, j: usize) -> usize { - self.cpu_cell(self.divu_by_zero, j) - } - - #[inline] - pub fn is_div_or_rem(&self, j: usize) -> usize { - self.cpu_cell(self.is_div_or_rem, j) - } - - #[inline] - pub fn div_nonzero(&self, j: usize) -> usize { - self.cpu_cell(self.div_nonzero, j) - } - - #[inline] - pub fn rem_nonzero(&self, j: usize) -> usize { - self.cpu_cell(self.rem_nonzero, j) - } - - #[inline] - pub fn div_by_zero(&self, j: usize) -> usize { - self.cpu_cell(self.div_by_zero, j) - } - - #[inline] - pub fn rem_by_zero(&self, j: usize) -> usize { - self.cpu_cell(self.rem_by_zero, j) - } - - #[inline] - pub fn div_sign(&self, j: usize) -> usize { - self.cpu_cell(self.div_sign, j) - } - - #[inline] - pub fn div_rem_check(&self, j: usize) -> usize { - self.cpu_cell(self.div_rem_check, j) - } - - #[inline] - pub fn div_rem_check_signed(&self, j: usize) -> usize { - self.cpu_cell(self.div_rem_check_signed, j) - } - - #[inline] - pub fn ecall_a0_bit(&self, bit: usize, j: usize) -> usize { - debug_assert!(bit < 32, "a0 bit out of range"); - self.ecall_a0_bits_start + bit * self.chunk_size + j - } - - #[inline] - pub fn ecall_cycle_prefix(&self, k: usize, j: usize) -> usize { - debug_assert!(k < 31, "ecall_cycle_prefix k out of range"); - 
self.ecall_cycle_prefix_start + k * self.chunk_size + j - } - - #[inline] - pub fn ecall_is_cycle(&self, j: usize) -> usize { - self.cpu_cell(self.ecall_is_cycle, j) - } - - #[inline] - pub fn ecall_print_prefix(&self, k: usize, j: usize) -> usize { - debug_assert!(k < 31, "ecall_print_prefix k out of range"); - self.ecall_print_prefix_start + k * self.chunk_size + j - } - - #[inline] - pub fn ecall_is_print(&self, j: usize) -> usize { - self.cpu_cell(self.ecall_is_print, j) - } - - #[inline] - pub fn ecall_halts(&self, j: usize) -> usize { - self.cpu_cell(self.ecall_halts, j) - } - - #[inline] - pub fn halt_effective(&self, j: usize) -> usize { - self.cpu_cell(self.halt_effective, j) - } - - #[inline] - pub fn sll_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.sll_has_lookup, j) - } - - #[inline] - pub fn srl_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.srl_has_lookup, j) - } - - #[inline] - pub fn sra_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.sra_has_lookup, j) - } - - #[inline] - pub fn slt_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.slt_has_lookup, j) - } - - #[inline] - pub fn sltu_has_lookup(&self, j: usize) -> usize { - self.cpu_cell(self.sltu_has_lookup, j) - } - - #[inline] - pub fn add_a0b0(&self, j: usize) -> usize { - self.cpu_cell(self.add_a0b0, j) - } - - #[inline] - pub fn imm12_raw(&self, j: usize) -> usize { - self.cpu_cell(self.imm12_raw, j) - } - - #[inline] - pub fn imm_i(&self, j: usize) -> usize { - self.cpu_cell(self.imm_i, j) - } - - #[inline] - pub fn imm_s(&self, j: usize) -> usize { - self.cpu_cell(self.imm_s, j) - } - - #[inline] - pub fn imm_u(&self, j: usize) -> usize { - self.cpu_cell(self.imm_u, j) - } - - #[inline] - pub fn imm_b_raw(&self, j: usize) -> usize { - self.cpu_cell(self.imm_b_raw, j) - } - - #[inline] - pub fn imm_b(&self, j: usize) -> usize { - self.cpu_cell(self.imm_b, j) - } - - #[inline] - pub fn imm_j_raw(&self, j: usize) -> usize { - 
self.cpu_cell(self.imm_j_raw, j) - } - - #[inline] - pub fn imm_j(&self, j: usize) -> usize { - self.cpu_cell(self.imm_j, j) - } - - #[inline] - pub fn shamt(&self, j: usize) -> usize { - // Shift amount lives in the same 5-bit field as `rs2_field` (instr bits [24:20]). - self.rs2_field(j) - } - - #[inline] - pub fn opcode(&self, j: usize) -> usize { - self.cpu_cell(self.opcode, j) - } - - #[inline] - pub fn funct3(&self, j: usize) -> usize { - self.cpu_cell(self.funct3, j) - } - - #[inline] - pub fn funct7(&self, j: usize) -> usize { - self.cpu_cell(self.funct7, j) - } - - #[inline] - pub fn rd_field(&self, j: usize) -> usize { - self.cpu_cell(self.rd_field, j) - } - - #[inline] - pub fn rs1_field(&self, j: usize) -> usize { - self.cpu_cell(self.rs1_field, j) - } - - #[inline] - pub fn rs2_field(&self, j: usize) -> usize { - self.cpu_cell(self.rs2_field, j) - } - - #[inline] - pub fn is_add(&self, j: usize) -> usize { - self.cpu_cell(self.is_add, j) - } - - #[inline] - pub fn is_sub(&self, j: usize) -> usize { - self.cpu_cell(self.is_sub, j) - } - - #[inline] - pub fn is_sll(&self, j: usize) -> usize { - self.cpu_cell(self.is_sll, j) - } - - #[inline] - pub fn is_slt(&self, j: usize) -> usize { - self.cpu_cell(self.is_slt, j) - } - - #[inline] - pub fn is_sltu(&self, j: usize) -> usize { - self.cpu_cell(self.is_sltu, j) - } - - #[inline] - pub fn is_xor(&self, j: usize) -> usize { - self.cpu_cell(self.is_xor, j) - } - - #[inline] - pub fn is_srl(&self, j: usize) -> usize { - self.cpu_cell(self.is_srl, j) - } - - #[inline] - pub fn is_sra(&self, j: usize) -> usize { - self.cpu_cell(self.is_sra, j) - } - - #[inline] - pub fn is_or(&self, j: usize) -> usize { - self.cpu_cell(self.is_or, j) - } - - #[inline] - pub fn is_and(&self, j: usize) -> usize { - self.cpu_cell(self.is_and, j) - } - - #[inline] - pub fn is_mul(&self, j: usize) -> usize { - self.cpu_cell(self.is_mul, j) - } - - #[inline] - pub fn is_mulh(&self, j: usize) -> usize { - self.cpu_cell(self.is_mulh, 
j) - } - - #[inline] - pub fn is_mulhu(&self, j: usize) -> usize { - self.cpu_cell(self.is_mulhu, j) - } - - #[inline] - pub fn is_mulhsu(&self, j: usize) -> usize { - self.cpu_cell(self.is_mulhsu, j) - } - - #[inline] - pub fn is_div(&self, j: usize) -> usize { - self.cpu_cell(self.is_div, j) - } - - #[inline] - pub fn is_divu(&self, j: usize) -> usize { - self.cpu_cell(self.is_divu, j) - } - - #[inline] - pub fn is_rem(&self, j: usize) -> usize { - self.cpu_cell(self.is_rem, j) - } - - #[inline] - pub fn is_remu(&self, j: usize) -> usize { - self.cpu_cell(self.is_remu, j) - } - - #[inline] - pub fn is_addi(&self, j: usize) -> usize { - self.cpu_cell(self.is_addi, j) - } - - #[inline] - pub fn is_slti(&self, j: usize) -> usize { - self.cpu_cell(self.is_slti, j) - } - - #[inline] - pub fn is_sltiu(&self, j: usize) -> usize { - self.cpu_cell(self.is_sltiu, j) - } - - #[inline] - pub fn is_xori(&self, j: usize) -> usize { - self.cpu_cell(self.is_xori, j) - } - - #[inline] - pub fn is_ori(&self, j: usize) -> usize { - self.cpu_cell(self.is_ori, j) - } - - #[inline] - pub fn is_andi(&self, j: usize) -> usize { - self.cpu_cell(self.is_andi, j) - } - - #[inline] - pub fn is_slli(&self, j: usize) -> usize { - self.cpu_cell(self.is_slli, j) - } - - #[inline] - pub fn is_srli(&self, j: usize) -> usize { - self.cpu_cell(self.is_srli, j) - } - - #[inline] - pub fn is_srai(&self, j: usize) -> usize { - self.cpu_cell(self.is_srai, j) - } - - #[inline] - pub fn is_lb(&self, j: usize) -> usize { - self.cpu_cell(self.is_lb, j) - } - - #[inline] - pub fn is_lbu(&self, j: usize) -> usize { - self.cpu_cell(self.is_lbu, j) - } - - #[inline] - pub fn is_lh(&self, j: usize) -> usize { - self.cpu_cell(self.is_lh, j) - } - - #[inline] - pub fn is_lhu(&self, j: usize) -> usize { - self.cpu_cell(self.is_lhu, j) - } - - #[inline] - pub fn is_lw(&self, j: usize) -> usize { - self.cpu_cell(self.is_lw, j) - } - - #[inline] - pub fn is_sb(&self, j: usize) -> usize { - self.cpu_cell(self.is_sb, 
j) - } - - #[inline] - pub fn is_sh(&self, j: usize) -> usize { - self.cpu_cell(self.is_sh, j) - } - - #[inline] - pub fn is_sw(&self, j: usize) -> usize { - self.cpu_cell(self.is_sw, j) - } - - #[inline] - pub fn is_amoswap_w(&self, j: usize) -> usize { - self.cpu_cell(self.is_amoswap_w, j) - } - - #[inline] - pub fn is_amoadd_w(&self, j: usize) -> usize { - self.cpu_cell(self.is_amoadd_w, j) - } - - #[inline] - pub fn is_amoxor_w(&self, j: usize) -> usize { - self.cpu_cell(self.is_amoxor_w, j) - } - - #[inline] - pub fn is_amoor_w(&self, j: usize) -> usize { - self.cpu_cell(self.is_amoor_w, j) - } - - #[inline] - pub fn is_amoand_w(&self, j: usize) -> usize { - self.cpu_cell(self.is_amoand_w, j) - } - - #[inline] - pub fn is_lui(&self, j: usize) -> usize { - self.cpu_cell(self.is_lui, j) - } - - #[inline] - pub fn is_auipc(&self, j: usize) -> usize { - self.cpu_cell(self.is_auipc, j) - } - - #[inline] - pub fn is_beq(&self, j: usize) -> usize { - self.cpu_cell(self.is_beq, j) - } - - #[inline] - pub fn is_bne(&self, j: usize) -> usize { - self.cpu_cell(self.is_bne, j) - } - - #[inline] - pub fn is_blt(&self, j: usize) -> usize { - self.cpu_cell(self.is_blt, j) - } - - #[inline] - pub fn is_bge(&self, j: usize) -> usize { - self.cpu_cell(self.is_bge, j) - } - - #[inline] - pub fn is_bltu(&self, j: usize) -> usize { - self.cpu_cell(self.is_bltu, j) - } - - #[inline] - pub fn is_bgeu(&self, j: usize) -> usize { - self.cpu_cell(self.is_bgeu, j) - } - - #[inline] - pub fn is_jal(&self, j: usize) -> usize { - self.cpu_cell(self.is_jal, j) - } - - #[inline] - pub fn is_jalr(&self, j: usize) -> usize { - self.cpu_cell(self.is_jalr, j) - } - - #[inline] - pub fn is_fence(&self, j: usize) -> usize { - self.cpu_cell(self.is_fence, j) - } - - #[inline] - pub fn is_halt(&self, j: usize) -> usize { - self.cpu_cell(self.is_halt, j) - } - - #[inline] - pub fn br_taken(&self, j: usize) -> usize { - self.cpu_cell(self.br_taken, j) - } - - #[inline] - pub fn br_not_taken(&self, j: 
usize) -> usize { - self.cpu_cell(self.br_not_taken, j) - } - - pub fn shout_idx(&self, table_id: u32) -> Result { - self.table_ids - .binary_search(&table_id) - .map_err(|_| format!("RV32 B1: table_ids missing required table_id={table_id}")) - } -} - -pub(super) fn build_layout_with_m( - m: usize, - mem_layouts: &HashMap, - shout_table_ids: &[u32], - chunk_size: usize, -) -> Result { - if chunk_size == 0 { - return Err("RV32 B1 layout: chunk_size must be >= 1".into()); - } - let const_one = 0usize; - - // Public inputs: initial and final architectural state. - // Layout: [const_one, pc0, regs0[32], pc_final, regs_final[32], halted_in, halted_out] - let pc0 = 1usize; - let regs0_start = pc0 + 1; - let pc_final = regs0_start + 32; - let regs_final_start = pc_final + 1; - let halted_in = regs_final_start + 32; - let halted_out = halted_in + 1; - let m_in = halted_out + 1; - - // Fixed CPU column allocation (CPU region only). All indices must be < bus.bus_base. - let mut col = m_in; - let alloc_scalar = |col: &mut usize| { - let base = *col; - *col += chunk_size; - base - }; - let alloc_array = |col: &mut usize, n: usize| { - let base = *col; - *col += n * chunk_size; - base - }; - - let is_active = alloc_scalar(&mut col); - let pc_in = alloc_scalar(&mut col); - let pc_out = alloc_scalar(&mut col); - let instr_word = alloc_scalar(&mut col); - - let regs_in_start = alloc_array(&mut col, 32); - let regs_out_start = alloc_array(&mut col, 32); - - let instr_bits_start = alloc_array(&mut col, 32); - - let opcode = alloc_scalar(&mut col); - let funct3 = alloc_scalar(&mut col); - let funct7 = alloc_scalar(&mut col); - let rd_field = alloc_scalar(&mut col); - let rs1_field = alloc_scalar(&mut col); - let rs2_field = alloc_scalar(&mut col); - - let imm12_raw = alloc_scalar(&mut col); - let imm_i = alloc_scalar(&mut col); - let imm_s = alloc_scalar(&mut col); - let imm_u = alloc_scalar(&mut col); - let imm_b_raw = alloc_scalar(&mut col); - let imm_b = alloc_scalar(&mut col); - 
let imm_j_raw = alloc_scalar(&mut col); - let imm_j = alloc_scalar(&mut col); - - let is_add = alloc_scalar(&mut col); - let is_sub = alloc_scalar(&mut col); - let is_sll = alloc_scalar(&mut col); - let is_slt = alloc_scalar(&mut col); - let is_sltu = alloc_scalar(&mut col); - let is_xor = alloc_scalar(&mut col); - let is_srl = alloc_scalar(&mut col); - let is_sra = alloc_scalar(&mut col); - let is_or = alloc_scalar(&mut col); - let is_and = alloc_scalar(&mut col); - - let is_mul = alloc_scalar(&mut col); - let is_mulh = alloc_scalar(&mut col); - let is_mulhu = alloc_scalar(&mut col); - let is_mulhsu = alloc_scalar(&mut col); - let is_div = alloc_scalar(&mut col); - let is_divu = alloc_scalar(&mut col); - let is_rem = alloc_scalar(&mut col); - let is_remu = alloc_scalar(&mut col); - - let is_addi = alloc_scalar(&mut col); - let is_slti = alloc_scalar(&mut col); - let is_sltiu = alloc_scalar(&mut col); - let is_xori = alloc_scalar(&mut col); - let is_ori = alloc_scalar(&mut col); - let is_andi = alloc_scalar(&mut col); - let is_slli = alloc_scalar(&mut col); - let is_srli = alloc_scalar(&mut col); - let is_srai = alloc_scalar(&mut col); - let is_lb = alloc_scalar(&mut col); - let is_lbu = alloc_scalar(&mut col); - let is_lh = alloc_scalar(&mut col); - let is_lhu = alloc_scalar(&mut col); - let is_lw = alloc_scalar(&mut col); - let is_sb = alloc_scalar(&mut col); - let is_sh = alloc_scalar(&mut col); - let is_sw = alloc_scalar(&mut col); - let is_amoswap_w = alloc_scalar(&mut col); - let is_amoadd_w = alloc_scalar(&mut col); - let is_amoxor_w = alloc_scalar(&mut col); - let is_amoor_w = alloc_scalar(&mut col); - let is_amoand_w = alloc_scalar(&mut col); - let is_lui = alloc_scalar(&mut col); - let is_auipc = alloc_scalar(&mut col); - let is_beq = alloc_scalar(&mut col); - let is_bne = alloc_scalar(&mut col); - let is_blt = alloc_scalar(&mut col); - let is_bge = alloc_scalar(&mut col); - let is_bltu = alloc_scalar(&mut col); - let is_bgeu = alloc_scalar(&mut col); - 
let is_jal = alloc_scalar(&mut col); - let is_jalr = alloc_scalar(&mut col); - let is_fence = alloc_scalar(&mut col); - let is_halt = alloc_scalar(&mut col); - - let br_taken = alloc_scalar(&mut col); - let br_not_taken = alloc_scalar(&mut col); - - let rs1_sel_start = alloc_array(&mut col, 32); - let rs2_sel_start = alloc_array(&mut col, 32); - let rd_sel_start = alloc_array(&mut col, 32); - - let rs1_val = alloc_scalar(&mut col); - let rs2_val = alloc_scalar(&mut col); - - let alu_out = alloc_scalar(&mut col); - let mem_rv = alloc_scalar(&mut col); - let mem_rv_bits_start = alloc_array(&mut col, 32); - let eff_addr = alloc_scalar(&mut col); - let ram_has_read = alloc_scalar(&mut col); - let ram_has_write = alloc_scalar(&mut col); - let ram_wv = alloc_scalar(&mut col); - let rd_write_val = alloc_scalar(&mut col); - let rd_write_bits_start = alloc_array(&mut col, 32); - - let add_has_lookup = alloc_scalar(&mut col); - let and_has_lookup = alloc_scalar(&mut col); - let xor_has_lookup = alloc_scalar(&mut col); - let or_has_lookup = alloc_scalar(&mut col); - let sll_has_lookup = alloc_scalar(&mut col); - let srl_has_lookup = alloc_scalar(&mut col); - let sra_has_lookup = alloc_scalar(&mut col); - let slt_has_lookup = alloc_scalar(&mut col); - let sltu_has_lookup = alloc_scalar(&mut col); - let lookup_key = alloc_scalar(&mut col); - let add_a0b0 = alloc_scalar(&mut col); - - // In-circuit RV32M helpers. 
- let mul_lo = alloc_scalar(&mut col); - let mul_hi = alloc_scalar(&mut col); - let mul_lo_bits_start = alloc_array(&mut col, 32); - let mul_hi_bits_start = alloc_array(&mut col, 32); - let mul_hi_prefix_start = alloc_array(&mut col, 31); - let mul_carry = alloc_scalar(&mut col); - let mul_carry_bits_start = alloc_array(&mut col, 2); - - let rs1_bits_start = alloc_array(&mut col, 32); - let rs2_bits_start = alloc_array(&mut col, 32); - let rs2_zero_prefix_start = alloc_array(&mut col, 31); - let rs1_abs = alloc_scalar(&mut col); - let rs2_abs = alloc_scalar(&mut col); - let rs1_rs2_sign_and = alloc_scalar(&mut col); - let rs1_sign_rs2_val = alloc_scalar(&mut col); - let rs2_sign_rs1_val = alloc_scalar(&mut col); - - let div_quot = alloc_scalar(&mut col); - let div_rem = alloc_scalar(&mut col); - let div_quot_signed = alloc_scalar(&mut col); - let div_rem_signed = alloc_scalar(&mut col); - let div_quot_carry = alloc_scalar(&mut col); - let div_rem_carry = alloc_scalar(&mut col); - let div_prod = alloc_scalar(&mut col); - let div_divisor = alloc_scalar(&mut col); - let div_quot_bits_start = alloc_array(&mut col, 32); - let div_rem_bits_start = alloc_array(&mut col, 32); - let rs2_is_zero = alloc_scalar(&mut col); - let rs2_nonzero = alloc_scalar(&mut col); - let is_divu_or_remu = alloc_scalar(&mut col); - let divu_by_zero = alloc_scalar(&mut col); - let is_div_or_rem = alloc_scalar(&mut col); - let div_nonzero = alloc_scalar(&mut col); - let rem_nonzero = alloc_scalar(&mut col); - let div_by_zero = alloc_scalar(&mut col); - let rem_by_zero = alloc_scalar(&mut col); - let div_sign = alloc_scalar(&mut col); - let div_rem_check = alloc_scalar(&mut col); - let div_rem_check_signed = alloc_scalar(&mut col); - let ecall_a0_bits_start = alloc_array(&mut col, 32); - let ecall_cycle_prefix_start = alloc_array(&mut col, 31); - let ecall_is_cycle = alloc_scalar(&mut col); - let ecall_print_prefix_start = alloc_array(&mut col, 31); - let ecall_is_print = alloc_scalar(&mut col); 
- let ecall_halts = alloc_scalar(&mut col); - let halt_effective = alloc_scalar(&mut col); - - let cpu_cols_used = col; - - let (mem_ids, twist_ell_addrs) = derive_mem_ids_and_ell_addrs(mem_layouts)?; - let (table_ids, shout_ell_addrs) = derive_shout_ids_and_ell_addrs(shout_table_ids)?; - - let bus = build_bus_layout_for_instances(m, m_in, chunk_size, shout_ell_addrs, twist_ell_addrs.clone())?; - if cpu_cols_used > bus.bus_base { - return Err(format!( - "RV32 B1 layout: CPU columns end at {cpu_cols_used}, but bus_base={} (need more padding columns before bus tail)", - bus.bus_base - )); - } - - // Determine which twist instance index corresponds to RAM/PROG in the sorted mem_ids order. - let ram_id = RAM_ID.0; - let prog_id = PROG_ID.0; - let ram_twist_idx = mem_ids - .iter() - .position(|&id| id == ram_id) - .ok_or_else(|| format!("mem_layouts missing RAM_ID={ram_id}"))?; - let prog_twist_idx = mem_ids - .iter() - .position(|&id| id == prog_id) - .ok_or_else(|| format!("mem_layouts missing PROG_ID={prog_id}"))?; - - Ok(Rv32B1Layout { - m_in, - m, - chunk_size, - const_one, - pc0, - regs0_start, - pc_final, - regs_final_start, - halted_in, - halted_out, - is_active, - pc_in, - pc_out, - instr_word, - regs_in_start, - regs_out_start, - instr_bits_start, - opcode, - funct3, - funct7, - rd_field, - rs1_field, - rs2_field, - imm12_raw, - imm_i, - imm_s, - imm_u, - imm_b_raw, - imm_b, - imm_j_raw, - imm_j, - is_add, - is_sub, - is_sll, - is_slt, - is_sltu, - is_xor, - is_srl, - is_sra, - is_or, - is_and, - is_mul, - is_mulh, - is_mulhu, - is_mulhsu, - is_div, - is_divu, - is_rem, - is_remu, - is_addi, - is_slti, - is_sltiu, - is_xori, - is_ori, - is_andi, - is_slli, - is_srli, - is_srai, - is_lb, - is_lbu, - is_lh, - is_lhu, - is_lw, - is_sb, - is_sh, - is_sw, - is_amoswap_w, - is_amoadd_w, - is_amoxor_w, - is_amoor_w, - is_amoand_w, - is_lui, - is_auipc, - is_beq, - is_bne, - is_blt, - is_bge, - is_bltu, - is_bgeu, - is_jal, - is_jalr, - is_fence, - is_halt, - 
br_taken, - br_not_taken, - rs1_sel_start, - rs2_sel_start, - rd_sel_start, - rs1_val, - rs2_val, - alu_out, - mem_rv, - mem_rv_bits_start, - eff_addr, - ram_has_read, - ram_has_write, - ram_wv, - rd_write_val, - rd_write_bits_start, - add_has_lookup, - and_has_lookup, - xor_has_lookup, - or_has_lookup, - sll_has_lookup, - srl_has_lookup, - sra_has_lookup, - slt_has_lookup, - sltu_has_lookup, - lookup_key, - add_a0b0, - mul_lo, - mul_hi, - mul_lo_bits_start, - mul_hi_bits_start, - mul_hi_prefix_start, - mul_carry, - mul_carry_bits_start, - rs1_bits_start, - rs2_bits_start, - rs2_zero_prefix_start, - rs1_abs, - rs2_abs, - rs1_rs2_sign_and, - rs1_sign_rs2_val, - rs2_sign_rs1_val, - div_quot, - div_rem, - div_quot_signed, - div_rem_signed, - div_quot_carry, - div_rem_carry, - div_prod, - div_divisor, - div_quot_bits_start, - div_rem_bits_start, - rs2_is_zero, - rs2_nonzero, - is_divu_or_remu, - divu_by_zero, - is_div_or_rem, - div_nonzero, - rem_nonzero, - div_by_zero, - rem_by_zero, - div_sign, - div_rem_check, - div_rem_check_signed, - ecall_a0_bits_start, - ecall_cycle_prefix_start, - ecall_is_cycle, - ecall_print_prefix_start, - ecall_is_print, - ecall_halts, - halt_effective, - bus, - mem_ids, - table_ids, - ram_twist_idx, - prog_twist_idx, - }) -} diff --git a/crates/neo-memory/src/riscv/ccs/trace.rs b/crates/neo-memory/src/riscv/ccs/trace.rs new file mode 100644 index 00000000..05213de0 --- /dev/null +++ b/crates/neo-memory/src/riscv/ccs/trace.rs @@ -0,0 +1,250 @@ +use neo_ccs::relations::CcsStructure; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +use crate::riscv::exec_table::Rv32ExecTable; +use crate::riscv::trace::{Rv32TraceLayout, Rv32TraceWitness}; + +use super::constraint_builder::{build_r1cs_ccs, Constraint}; + +/// Fixed-width, time-in-rows trace CCS layout. +/// +/// This is a Tier 2.1 trace CCS with fixed columns over time (`t` rows), +/// AIR-like wiring invariants, and a compact subset of ISA semantics guards. 
+/// +/// Witness layout (column-major trace region): +/// `cell(trace_col, row) = trace_base + trace_col * t + row`. +#[derive(Clone, Debug)] +pub struct Rv32TraceCcsLayout { + pub t: usize, + pub m_in: usize, + pub m: usize, + + // Public scalars. + pub const_one: usize, + pub pc0: usize, + pub pc_final: usize, + pub halted_in: usize, + pub halted_out: usize, + + pub trace_base: usize, + pub trace: Rv32TraceLayout, +} + +impl Rv32TraceCcsLayout { + pub fn new(t: usize) -> Result { + if t == 0 { + return Err("Rv32TraceCcsLayout: t must be >= 1".into()); + } + + let const_one: usize = 0; + let pc0: usize = 1; + let pc_final: usize = 2; + let halted_in: usize = 3; + let halted_out: usize = 4; + let m_in: usize = 5; + + let trace = Rv32TraceLayout::new(); + let trace_base = m_in; + let trace_len = trace + .cols + .checked_mul(t) + .ok_or_else(|| "Rv32TraceCcsLayout: trace_len overflow".to_string())?; + let m = trace_base + .checked_add(trace_len) + .ok_or_else(|| "Rv32TraceCcsLayout: m overflow".to_string())?; + + Ok(Self { + t, + m_in, + m, + const_one, + pc0, + pc_final, + halted_in, + halted_out, + trace_base, + trace, + }) + } + + /// Full witness index for a trace cell. + #[inline] + pub fn cell(&self, trace_col: usize, row: usize) -> usize { + debug_assert!(trace_col < self.trace.cols); + debug_assert!(row < self.t); + self.trace_base + trace_col * self.t + row + } +} + +/// Build the public inputs `x` and witness `w` for the trace CCS from an exec table. +pub fn rv32_trace_ccs_witness_from_exec_table( + layout: &Rv32TraceCcsLayout, + exec: &Rv32ExecTable, +) -> Result<(Vec, Vec), String> { + let wit = Rv32TraceWitness::from_exec_table(&layout.trace, exec)?; + rv32_trace_ccs_witness_from_trace_witness(layout, &wit) +} + +/// Build the public inputs `x` and witness `w` for the trace CCS from a trace witness. 
+pub fn rv32_trace_ccs_witness_from_trace_witness( + layout: &Rv32TraceCcsLayout, + wit: &Rv32TraceWitness, +) -> Result<(Vec, Vec), String> { + if wit.t != layout.t { + return Err(format!( + "trace CCS witness: t mismatch (wit.t={} layout.t={})", + wit.t, layout.t + )); + } + if wit.cols.len() != layout.trace.cols { + return Err(format!( + "trace CCS witness: width mismatch (wit.cols={} trace.cols={})", + wit.cols.len(), + layout.trace.cols + )); + } + + let mut x = vec![F::ZERO; layout.m_in]; + x[layout.const_one] = F::ONE; + x[layout.pc0] = wit.cols[layout.trace.pc_before][0]; + x[layout.pc_final] = wit.cols[layout.trace.pc_after][layout.t - 1]; + x[layout.halted_in] = wit.cols[layout.trace.halted][0]; + x[layout.halted_out] = wit.cols[layout.trace.halted][layout.t - 1]; + + let mut w = vec![F::ZERO; layout.m - layout.m_in]; + for trace_col in 0..layout.trace.cols { + let col = &wit.cols[trace_col]; + for row in 0..layout.t { + let idx = layout.cell(trace_col, row); + w[idx - layout.m_in] = col[row]; + } + } + + Ok((x, w)) +} + +/// Build the base trace CCS (wiring invariants + partial ISA semantics guards). +pub fn build_rv32_trace_wiring_ccs(layout: &Rv32TraceCcsLayout) -> Result, String> { + build_rv32_trace_wiring_ccs_with_reserved_rows(layout, 0) +} + +pub fn build_rv32_trace_wiring_ccs_with_reserved_rows( + layout: &Rv32TraceCcsLayout, + reserved_rows: usize, +) -> Result, String> { + let one = layout.const_one; + let t = layout.t; + let tr = |c: usize, i: usize| -> usize { layout.cell(c, i) }; + let l = &layout.trace; + + let mut cons: Vec> = Vec::new(); + + // Public bindings. 
+ cons.push(Constraint::terms( + one, + false, + vec![(layout.pc0, F::ONE), (tr(l.pc_before, 0), -F::ONE)], + )); + cons.push(Constraint::terms( + one, + false, + vec![(layout.pc_final, F::ONE), (tr(l.pc_after, t - 1), -F::ONE)], + )); + cons.push(Constraint::terms( + one, + false, + vec![(layout.halted_in, F::ONE), (tr(l.halted, 0), -F::ONE)], + )); + cons.push(Constraint::terms( + one, + false, + vec![(layout.halted_out, F::ONE), (tr(l.halted, t - 1), -F::ONE)], + )); + // Execution anchor: the first trace row must be active. + cons.push(Constraint::terms( + one, + false, + vec![(tr(l.active, 0), F::ONE), (one, -F::ONE)], + )); + + for i in 0..t { + let _halted = tr(l.halted, i); + let shout_has_lookup = tr(l.shout_has_lookup, i); + + // Canonical AIR-style one-column. + cons.push(Constraint::terms( + one, + false, + vec![(tr(l.one, i), F::ONE), (one, -F::ONE)], + )); + + // Booleanity and inactive-row quiescence are enforced by WB/WP sidecar stages. + + // Shout padding: (1 - has_lookup) * val == 0. 
+ cons.push(Constraint::terms( + shout_has_lookup, + true, + vec![(tr(l.shout_val, i), F::ONE)], + )); + cons.push(Constraint::terms( + shout_has_lookup, + true, + vec![(tr(l.shout_lhs, i), F::ONE)], + )); + cons.push(Constraint::terms( + shout_has_lookup, + true, + vec![(tr(l.shout_rhs, i), F::ONE)], + )); + } + + for i in 0..t.saturating_sub(1) { + // pc_after[i] == pc_before[i+1] + cons.push(Constraint::terms( + one, + false, + vec![(tr(l.pc_after, i), F::ONE), (tr(l.pc_before, i + 1), -F::ONE)], + )); + + // cycle[i+1] == cycle[i] + 1 + cons.push(Constraint::terms( + one, + false, + vec![(tr(l.cycle, i + 1), F::ONE), (tr(l.cycle, i), -F::ONE), (one, -F::ONE)], + )); + + // Once inactive, remain inactive: active[i+1] * (1 - active[i]) == 0 + cons.push(Constraint::terms( + tr(l.active, i + 1), + false, + vec![(one, F::ONE), (tr(l.active, i), -F::ONE)], + )); + + // Once halted, remain halted: halted[i] * (1 - halted[i+1]) == 0 + cons.push(Constraint::terms( + tr(l.halted, i), + false, + vec![(one, F::ONE), (tr(l.halted, i + 1), -F::ONE)], + )); + + // Halted tail quiescence: + // once halted, the next row must be inactive and keep the same pc_after. 
+ cons.push(Constraint::terms( + tr(l.halted, i), + false, + vec![(tr(l.active, i + 1), F::ONE)], + )); + cons.push(Constraint::terms( + tr(l.halted, i), + false, + vec![(tr(l.pc_after, i), F::ONE), (tr(l.pc_after, i + 1), -F::ONE)], + )); + } + + let n = cons + .len() + .checked_add(reserved_rows) + .ok_or_else(|| "RV32 trace CCS: n overflow".to_string())?; + build_r1cs_ccs(&cons, n, layout.m, layout.const_one) +} diff --git a/crates/neo-memory/src/riscv/ccs/witness.rs b/crates/neo-memory/src/riscv/ccs/witness.rs deleted file mode 100644 index a4f673a7..00000000 --- a/crates/neo-memory/src/riscv/ccs/witness.rs +++ /dev/null @@ -1,1249 +0,0 @@ -use p3_field::{PrimeCharacteristicRing, PrimeField64}; -use p3_goldilocks::Goldilocks as F; - -use neo_vm_trace::{StepTrace, TwistOpKind}; - -use crate::riscv::lookups::{ - decode_instruction, BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode, JOLT_CYCLE_TRACK_ECALL_NUM, - JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID, -}; - -use super::constants::{ - ADD_TABLE_ID, AND_TABLE_ID, EQ_TABLE_ID, NEQ_TABLE_ID, OR_TABLE_ID, SLL_TABLE_ID, SLTU_TABLE_ID, SLT_TABLE_ID, - SRA_TABLE_ID, SRL_TABLE_ID, SUB_TABLE_ID, XOR_TABLE_ID, -}; -use super::Rv32B1Layout; - -#[inline] -fn set_bus_cell(z: &mut [Ff], layout: &Rv32B1Layout, bus_col: usize, j: usize, val: Ff) { - let col = layout.bus.bus_cell(bus_col, j); - z[col] = val; -} - -#[inline] -fn write_bus_u64_bits( - z: &mut [Ff], - layout: &Rv32B1Layout, - start_bus_col: usize, - len: usize, - j: usize, - mut value: u64, -) { - assert!( - len <= 64, - "RV32 B1 witness: bus bit range too large for u64 writer (len={len})" - ); - for k in 0..len { - let bit = (value & 1) as u64; - value >>= 1; - set_bus_cell( - z, - layout, - start_bus_col + k, - j, - if bit == 1 { Ff::ONE } else { Ff::ZERO }, - ); - } -} - -fn set_ecall_helpers(z: &mut [F], layout: &Rv32B1Layout, j: usize, a0_u64: u64, is_halt: bool) -> Result<(), String> { - let a0_u32 = u32::try_from(a0_u64).map_err(|_| format!("RV32 B1: a0 
value does not fit in u32: {a0_u64}"))?; - - for bit in 0..32 { - z[layout.ecall_a0_bit(bit, j)] = if ((a0_u32 >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; - } - - let cycle_const = JOLT_CYCLE_TRACK_ECALL_NUM; - let print_const = JOLT_PRINT_ECALL_NUM; - - let mut cycle_prefix = if ((a0_u32 ^ cycle_const) & 1) == 0 { 1u32 } else { 0u32 }; - z[layout.ecall_cycle_prefix(0, j)] = if cycle_prefix == 1 { F::ONE } else { F::ZERO }; - for k in 1..31 { - let bit_match = ((a0_u32 >> k) ^ (cycle_const >> k)) & 1; - cycle_prefix &= 1u32 ^ bit_match; - z[layout.ecall_cycle_prefix(k, j)] = if cycle_prefix == 1 { F::ONE } else { F::ZERO }; - } - let cycle_match31 = (((a0_u32 >> 31) ^ (cycle_const >> 31)) & 1) == 0; - let is_cycle = cycle_prefix == 1 && cycle_match31; - z[layout.ecall_is_cycle(j)] = if is_cycle { F::ONE } else { F::ZERO }; - - let mut print_prefix = if ((a0_u32 ^ print_const) & 1) == 0 { 1u32 } else { 0u32 }; - z[layout.ecall_print_prefix(0, j)] = if print_prefix == 1 { F::ONE } else { F::ZERO }; - for k in 1..31 { - let bit_match = ((a0_u32 >> k) ^ (print_const >> k)) & 1; - print_prefix &= 1u32 ^ bit_match; - z[layout.ecall_print_prefix(k, j)] = if print_prefix == 1 { F::ONE } else { F::ZERO }; - } - let print_match31 = (((a0_u32 >> 31) ^ (print_const >> 31)) & 1) == 0; - let is_print = print_prefix == 1 && print_match31; - z[layout.ecall_is_print(j)] = if is_print { F::ONE } else { F::ZERO }; - - let ecall_halts = !(is_cycle || is_print); - z[layout.ecall_halts(j)] = if ecall_halts { F::ONE } else { F::ZERO }; - z[layout.halt_effective(j)] = if is_halt && ecall_halts { F::ONE } else { F::ZERO }; - - Ok(()) -} - -/// Build a CPU witness vector `z` for shared-bus mode. -/// -/// In shared-bus mode, `R1csCpu` overwrites the reserved bus tail from `StepTrace` events, so this -/// witness builder leaves the bus region at its zero default and only populates CPU columns. 
-pub fn rv32_b1_chunk_to_witness(layout: Rv32B1Layout) -> Box]) -> Vec + Send + Sync> { - Box::new(move |chunk: &[StepTrace]| { - rv32_b1_chunk_to_witness_checked(&layout, chunk).unwrap_or_else(|e| { - panic!("RV32 B1 witness build failed: {e}"); - }) - }) -} - -/// Build a full witness vector `z`, including the bus tail (standalone/debug/test use). -pub fn rv32_b1_chunk_to_full_witness( - layout: Rv32B1Layout, -) -> Box]) -> Vec + Send + Sync> { - Box::new(move |chunk: &[StepTrace]| { - rv32_b1_chunk_to_full_witness_checked(&layout, chunk).unwrap_or_else(|e| { - panic!("RV32 B1 full witness build failed: {e}"); - }) - }) -} - -pub fn rv32_b1_chunk_to_witness_checked( - layout: &Rv32B1Layout, - chunk: &[StepTrace], -) -> Result, String> { - rv32_b1_chunk_to_witness_internal(layout, chunk, /*fill_bus=*/ false) -} - -pub fn rv32_b1_chunk_to_full_witness_checked( - layout: &Rv32B1Layout, - chunk: &[StepTrace], -) -> Result, String> { - rv32_b1_chunk_to_witness_internal(layout, chunk, /*fill_bus=*/ true) -} - -fn rv32_b1_chunk_to_witness_internal( - layout: &Rv32B1Layout, - chunk: &[StepTrace], - fill_bus: bool, -) -> Result, String> { - let mut z = vec![F::ZERO; layout.m]; - - z[layout.const_one] = F::ONE; - - let add_shout_idx = layout - .shout_idx(ADD_TABLE_ID) - .map_err(|e| format!("RV32 B1: {e}"))?; - let add_lane = &layout.bus.shout_cols[add_shout_idx].lanes[0]; - let prog_lane = &layout.bus.twist_cols[layout.prog_twist_idx].lanes[0]; - let ram_lane = &layout.bus.twist_cols[layout.ram_twist_idx].lanes[0]; - - // Carry the architectural state forward through padding rows. - // Initialize from the chunk's start state so fully-inactive chunks are well-defined. 
- let mut carried_pc = 0u64; - let mut carried_regs = [0u64; 32]; - - if let Some(first) = chunk.first() { - z[layout.pc0] = F::from_u64(first.pc_before); - for r in 0..32 { - z[layout.regs0_start + r] = F::from_u64(first.regs_before[r]); - carried_regs[r] = first.regs_before[r]; - } - z[layout.regs0_start] = F::ZERO; - carried_regs[0] = 0; - carried_pc = first.pc_before; - } - - for j in 0..layout.chunk_size { - if j >= chunk.len() { - z[layout.is_active(j)] = F::ZERO; - - z[layout.pc_in(j)] = F::from_u64(carried_pc); - z[layout.pc_out(j)] = F::from_u64(carried_pc); - for r in 0..32 { - z[layout.reg_in(r, j)] = F::from_u64(carried_regs[r]); - z[layout.reg_out(r, j)] = F::from_u64(carried_regs[r]); - } - // Columns constrained independently of `is_active` must be set consistently on padding rows. - for bit in 0..32 { - z[layout.rd_write_bit(bit, j)] = F::ZERO; - z[layout.mem_rv_bit(bit, j)] = F::ZERO; - z[layout.mul_lo_bit(bit, j)] = F::ZERO; - z[layout.mul_hi_bit(bit, j)] = F::ZERO; - z[layout.div_quot_bit(bit, j)] = F::ZERO; - z[layout.div_rem_bit(bit, j)] = F::ZERO; - z[layout.rs1_bit(bit, j)] = F::ZERO; - z[layout.rs2_bit(bit, j)] = F::ZERO; - } - for bit in 0..2 { - z[layout.mul_carry_bit(bit, j)] = F::ZERO; - } - for k in 0..31 { - z[layout.rs2_zero_prefix(k, j)] = F::ONE; - } - for k in 0..31 { - z[layout.mul_hi_prefix(k, j)] = F::ZERO; - } - z[layout.rs2_is_zero(j)] = F::ONE; - z[layout.rs2_nonzero(j)] = F::ZERO; - set_ecall_helpers(&mut z, layout, j, carried_regs[10], false)?; - continue; - } - let step = &chunk[j]; - - // A row is active iff it contains exactly one PROG_ID read (B1 instruction fetch). - // Padding rows contain no Twist/Shout events and are treated as inactive. 
- let mut prog_read: Option<(u64, u64)> = None; - let mut ram_read: Option<(u64, u64)> = None; - let mut ram_write: Option<(u64, u64)> = None; - for ev in &step.twist_events { - if ev.twist_id == PROG_ID { - match ev.kind { - TwistOpKind::Read => { - if prog_read.replace((ev.addr, ev.value)).is_some() { - return Err(format!( - "RV32 B1: multiple PROG_ID reads in one step at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } - TwistOpKind::Write => { - return Err(format!( - "RV32 B1: unexpected PROG_ID write at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } - } else if ev.twist_id == RAM_ID { - match ev.kind { - TwistOpKind::Read => { - if ram_read.replace((ev.addr, ev.value)).is_some() { - return Err(format!( - "RV32 B1: multiple RAM reads in one step at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } - TwistOpKind::Write => { - if ram_write.replace((ev.addr, ev.value)).is_some() { - return Err(format!( - "RV32 B1: multiple RAM writes in one step at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } - } - } else { - return Err(format!( - "RV32 B1: unexpected twist_id={} at pc={:#x} (chunk j={j})", - ev.twist_id.0, step.pc_before - )); - } - } - - if prog_read.is_none() { - if !step.twist_events.is_empty() || !step.shout_events.is_empty() { - return Err(format!( - "RV32 B1: non-empty events in inactive row at step cycle={} (chunk j={j})", - step.cycle - )); - } - - z[layout.is_active(j)] = F::ZERO; - z[layout.pc_in(j)] = F::from_u64(carried_pc); - z[layout.pc_out(j)] = F::from_u64(carried_pc); - for r in 0..32 { - z[layout.reg_in(r, j)] = F::from_u64(carried_regs[r]); - z[layout.reg_out(r, j)] = F::from_u64(carried_regs[r]); - } - // Columns constrained independently of `is_active` must be set consistently on padding rows. 
- for bit in 0..32 { - z[layout.rd_write_bit(bit, j)] = F::ZERO; - z[layout.mem_rv_bit(bit, j)] = F::ZERO; - z[layout.mul_lo_bit(bit, j)] = F::ZERO; - z[layout.mul_hi_bit(bit, j)] = F::ZERO; - z[layout.div_quot_bit(bit, j)] = F::ZERO; - z[layout.div_rem_bit(bit, j)] = F::ZERO; - z[layout.rs1_bit(bit, j)] = F::ZERO; - z[layout.rs2_bit(bit, j)] = F::ZERO; - } - for bit in 0..2 { - z[layout.mul_carry_bit(bit, j)] = F::ZERO; - } - for k in 0..31 { - z[layout.rs2_zero_prefix(k, j)] = F::ONE; - } - for k in 0..31 { - z[layout.mul_hi_prefix(k, j)] = F::ZERO; - } - z[layout.rs2_is_zero(j)] = F::ONE; - z[layout.rs2_nonzero(j)] = F::ZERO; - set_ecall_helpers(&mut z, layout, j, carried_regs[10], false)?; - continue; - } - - z[layout.is_active(j)] = F::ONE; - z[layout.pc_in(j)] = F::from_u64(step.pc_before); - z[layout.pc_out(j)] = F::from_u64(step.pc_after); - - // Registers. - for r in 0..32 { - z[layout.reg_in(r, j)] = F::from_u64(step.regs_before[r]); - z[layout.reg_out(r, j)] = F::from_u64(step.regs_after[r]); - } - - carried_pc = step.pc_after; - carried_regs.copy_from_slice(&step.regs_after); - carried_regs[0] = 0; - - // Instruction word: read from PROG_ID Twist event (commitment-bound source). 
- let (prog_addr, prog_value) = prog_read.expect("checked prog_read is present"); - if prog_addr != step.pc_before { - return Err(format!( - "RV32 B1: PROG_ID read addr mismatch at pc={:#x} (chunk j={j}): read_addr={:#x}", - step.pc_before, prog_addr - )); - } - let instr_word_u32 = u32::try_from(prog_value).map_err(|_| { - format!( - "RV32 B1: PROG_ID read value does not fit in u32 at pc={:#x}: value={:#x}", - step.pc_before, prog_value - ) - })?; - z[layout.instr_word(j)] = F::from_u64(instr_word_u32 as u64); - if fill_bus { - set_bus_cell(&mut z, layout, prog_lane.has_read, j, F::ONE); - set_bus_cell(&mut z, layout, prog_lane.has_write, j, F::ZERO); - write_bus_u64_bits( - &mut z, - layout, - prog_lane.ra_bits.start, - prog_lane.ra_bits.end - prog_lane.ra_bits.start, - j, - prog_addr, - ); - set_bus_cell(&mut z, layout, prog_lane.rv, j, F::from_u64(prog_value)); - set_bus_cell(&mut z, layout, prog_lane.wv, j, F::ZERO); - set_bus_cell(&mut z, layout, prog_lane.inc, j, F::ZERO); - } - - // Bits. - for i in 0..32 { - z[layout.instr_bit(i, j)] = if ((instr_word_u32 >> i) & 1) == 1 { - F::ONE - } else { - F::ZERO - }; - } - - // Decode fields. - let opcode = instr_word_u32 & 0x7f; - let rd = (instr_word_u32 >> 7) & 0x1f; - let funct3 = (instr_word_u32 >> 12) & 0x7; - let rs1 = (instr_word_u32 >> 15) & 0x1f; - let rs2 = (instr_word_u32 >> 20) & 0x1f; - let funct7 = (instr_word_u32 >> 25) & 0x7f; - - z[layout.opcode(j)] = F::from_u64(opcode as u64); - z[layout.funct3(j)] = F::from_u64(funct3 as u64); - z[layout.funct7(j)] = F::from_u64(funct7 as u64); - z[layout.rd_field(j)] = F::from_u64(rd as u64); - z[layout.rs1_field(j)] = F::from_u64(rs1 as u64); - z[layout.rs2_field(j)] = F::from_u64(rs2 as u64); - - // Helpers for immediate representations: - // - `sx_u32` matches the CCS u32-style encoding used for imm_i / imm_s. - // - `from_i32` matches the CCS signed encoding used for imm_b / imm_j. 
- let sx_u32 = |x: i32| x as u32 as u64; - let from_i32 = |v: i32| -> F { - if v >= 0 { - F::from_u64(v as u64) - } else { - -F::from_u64((-v) as u64) - } - }; - - // Immediate raw fields. - let imm12_raw = ((instr_word_u32 >> 20) & 0xfff) as u32; - z[layout.imm12_raw(j)] = F::from_u64(imm12_raw as u64); - - // I-type immediate (sign-extended 12-bit). - let imm_i = ((imm12_raw as i32) << 20) >> 20; - z[layout.imm_i(j)] = F::from_u64(sx_u32(imm_i)); - - // S-type immediate (sign-extended 12-bit). - let imm_s_raw = (((instr_word_u32 >> 7) & 0x1f) | (((instr_word_u32 >> 25) & 0x7f) << 5)) as u32; - let imm_s = ((imm_s_raw as i32) << 20) >> 20; - z[layout.imm_s(j)] = F::from_u64(sx_u32(imm_s)); - - // U-type immediate (upper 20 bits). - let imm_u = (instr_word_u32 & 0xfffff000) as u64; - z[layout.imm_u(j)] = F::from_u64(imm_u); - - // B-type immediate raw bits (before sign extension). - let imm_b_raw = (((instr_word_u32 >> 7) & 0x1) << 11) - | (((instr_word_u32 >> 8) & 0xf) << 1) - | (((instr_word_u32 >> 25) & 0x3f) << 5) - | (((instr_word_u32 >> 31) & 0x1) << 12); - z[layout.imm_b_raw(j)] = F::from_u64(imm_b_raw as u64); - let imm_b = ((imm_b_raw as i32) << 19) >> 19; - z[layout.imm_b(j)] = from_i32(imm_b); - - // J-type immediate raw bits (before sign extension). - let imm_j_raw = (((instr_word_u32 >> 21) & 0x3ff) << 1) - | (((instr_word_u32 >> 20) & 0x1) << 11) - | (((instr_word_u32 >> 12) & 0xff) << 12) - | (((instr_word_u32 >> 31) & 0x1) << 20); - z[layout.imm_j_raw(j)] = F::from_u64(imm_j_raw as u64); - let imm_j = ((imm_j_raw as i32) << 11) >> 11; - z[layout.imm_j(j)] = from_i32(imm_j); - - // One-hot flags: use the shared decoder as the single source of truth. 
- let decoded = decode_instruction(instr_word_u32) - .map_err(|e| format!("RV32 B1: decode failed at pc={:#x}: {e}", step.pc_before))?; - - let mut is_add = false; - let mut is_sub = false; - let mut is_sll = false; - let mut is_slt = false; - let mut is_sltu = false; - let mut is_xor = false; - let mut is_srl = false; - let mut is_sra = false; - let mut is_or = false; - let mut is_and = false; - let mut is_mul = false; - let mut is_mulh = false; - let mut is_mulhu = false; - let mut is_mulhsu = false; - let mut is_div = false; - let mut is_divu = false; - let mut is_rem = false; - let mut is_remu = false; - - let mut is_addi = false; - let mut is_slti = false; - let mut is_sltiu = false; - let mut is_xori = false; - let mut is_ori = false; - let mut is_andi = false; - let mut is_slli = false; - let mut is_srli = false; - let mut is_srai = false; - - let mut is_lb = false; - let mut is_lbu = false; - let mut is_lh = false; - let mut is_lhu = false; - let mut is_lw = false; - let mut is_sb = false; - let mut is_sh = false; - let mut is_sw = false; - let mut is_amoswap_w = false; - let mut is_amoadd_w = false; - let mut is_amoxor_w = false; - let mut is_amoor_w = false; - let mut is_amoand_w = false; - let mut is_lui = false; - let mut is_auipc = false; - let mut is_beq = false; - let mut is_bne = false; - let mut is_blt = false; - let mut is_bge = false; - let mut is_bltu = false; - let mut is_bgeu = false; - let mut is_jal = false; - let mut is_jalr = false; - let mut is_fence = false; - let mut is_halt = false; - - match decoded { - RiscvInstruction::RAlu { op, .. 
} => match op { - RiscvOpcode::Add => is_add = true, - RiscvOpcode::Sub => is_sub = true, - RiscvOpcode::Sll => is_sll = true, - RiscvOpcode::Slt => is_slt = true, - RiscvOpcode::Sltu => is_sltu = true, - RiscvOpcode::Xor => is_xor = true, - RiscvOpcode::Srl => is_srl = true, - RiscvOpcode::Sra => is_sra = true, - RiscvOpcode::Or => is_or = true, - RiscvOpcode::And => is_and = true, - RiscvOpcode::Mul => is_mul = true, - RiscvOpcode::Mulh => is_mulh = true, - RiscvOpcode::Mulhu => is_mulhu = true, - RiscvOpcode::Mulhsu => is_mulhsu = true, - RiscvOpcode::Div => is_div = true, - RiscvOpcode::Divu => is_divu = true, - RiscvOpcode::Rem => is_rem = true, - RiscvOpcode::Remu => is_remu = true, - _ => {} - }, - RiscvInstruction::IAlu { op, .. } => match op { - RiscvOpcode::Add => is_addi = true, - RiscvOpcode::Slt => is_slti = true, - RiscvOpcode::Sltu => is_sltiu = true, - RiscvOpcode::Xor => is_xori = true, - RiscvOpcode::Or => is_ori = true, - RiscvOpcode::And => is_andi = true, - RiscvOpcode::Sll => is_slli = true, - RiscvOpcode::Srl => is_srli = true, - RiscvOpcode::Sra => is_srai = true, - _ => {} - }, - RiscvInstruction::Load { op, .. } => match op { - RiscvMemOp::Lb => is_lb = true, - RiscvMemOp::Lbu => is_lbu = true, - RiscvMemOp::Lh => is_lh = true, - RiscvMemOp::Lhu => is_lhu = true, - RiscvMemOp::Lw => is_lw = true, - _ => {} - }, - RiscvInstruction::Store { op, .. } => match op { - RiscvMemOp::Sb => is_sb = true, - RiscvMemOp::Sh => is_sh = true, - RiscvMemOp::Sw => is_sw = true, - _ => {} - }, - RiscvInstruction::Amo { op, .. } => match op { - RiscvMemOp::AmoswapW => is_amoswap_w = true, - RiscvMemOp::AmoaddW => is_amoadd_w = true, - RiscvMemOp::AmoxorW => is_amoxor_w = true, - RiscvMemOp::AmoorW => is_amoor_w = true, - RiscvMemOp::AmoandW => is_amoand_w = true, - _ => {} - }, - RiscvInstruction::Lui { .. } => is_lui = true, - RiscvInstruction::Auipc { .. } => is_auipc = true, - RiscvInstruction::Branch { cond, .. 
} => match cond { - BranchCondition::Eq => is_beq = true, - BranchCondition::Ne => is_bne = true, - BranchCondition::Lt => is_blt = true, - BranchCondition::Ge => is_bge = true, - BranchCondition::Ltu => is_bltu = true, - BranchCondition::Geu => is_bgeu = true, - }, - RiscvInstruction::Jal { .. } => is_jal = true, - RiscvInstruction::Jalr { .. } => is_jalr = true, - RiscvInstruction::Fence { .. } => is_fence = true, - RiscvInstruction::Halt => is_halt = true, - _ => {} - } - - // Reject unsupported instructions. - if !(is_add - || is_sub - || is_sll - || is_slt - || is_sltu - || is_xor - || is_srl - || is_sra - || is_or - || is_and - || is_mul - || is_mulh - || is_mulhu - || is_mulhsu - || is_div - || is_divu - || is_rem - || is_remu - || is_addi - || is_slti - || is_sltiu - || is_xori - || is_ori - || is_andi - || is_slli - || is_srli - || is_srai - || is_lb - || is_lbu - || is_lh - || is_lhu - || is_lw - || is_sb - || is_sh - || is_sw - || is_amoswap_w - || is_amoadd_w - || is_amoxor_w - || is_amoor_w - || is_amoand_w - || is_lui - || is_auipc - || is_beq - || is_bne - || is_blt - || is_bge - || is_bltu - || is_bgeu - || is_jal - || is_jalr - || is_fence - || is_halt) - { - return Err(format!( - "RV32 B1: unsupported instruction at pc={:#x}: word={:#x}", - step.pc_before, instr_word_u32 - )); - } - - z[layout.is_add(j)] = if is_add { F::ONE } else { F::ZERO }; - z[layout.is_sub(j)] = if is_sub { F::ONE } else { F::ZERO }; - z[layout.is_sll(j)] = if is_sll { F::ONE } else { F::ZERO }; - z[layout.is_slt(j)] = if is_slt { F::ONE } else { F::ZERO }; - z[layout.is_sltu(j)] = if is_sltu { F::ONE } else { F::ZERO }; - z[layout.is_xor(j)] = if is_xor { F::ONE } else { F::ZERO }; - z[layout.is_srl(j)] = if is_srl { F::ONE } else { F::ZERO }; - z[layout.is_sra(j)] = if is_sra { F::ONE } else { F::ZERO }; - z[layout.is_or(j)] = if is_or { F::ONE } else { F::ZERO }; - z[layout.is_and(j)] = if is_and { F::ONE } else { F::ZERO }; - z[layout.is_mul(j)] = if is_mul { F::ONE } 
else { F::ZERO }; - z[layout.is_mulh(j)] = if is_mulh { F::ONE } else { F::ZERO }; - z[layout.is_mulhu(j)] = if is_mulhu { F::ONE } else { F::ZERO }; - z[layout.is_mulhsu(j)] = if is_mulhsu { F::ONE } else { F::ZERO }; - z[layout.is_div(j)] = if is_div { F::ONE } else { F::ZERO }; - z[layout.is_divu(j)] = if is_divu { F::ONE } else { F::ZERO }; - z[layout.is_rem(j)] = if is_rem { F::ONE } else { F::ZERO }; - z[layout.is_remu(j)] = if is_remu { F::ONE } else { F::ZERO }; - z[layout.is_addi(j)] = if is_addi { F::ONE } else { F::ZERO }; - z[layout.is_slti(j)] = if is_slti { F::ONE } else { F::ZERO }; - z[layout.is_sltiu(j)] = if is_sltiu { F::ONE } else { F::ZERO }; - z[layout.is_xori(j)] = if is_xori { F::ONE } else { F::ZERO }; - z[layout.is_ori(j)] = if is_ori { F::ONE } else { F::ZERO }; - z[layout.is_andi(j)] = if is_andi { F::ONE } else { F::ZERO }; - z[layout.is_slli(j)] = if is_slli { F::ONE } else { F::ZERO }; - z[layout.is_srli(j)] = if is_srli { F::ONE } else { F::ZERO }; - z[layout.is_srai(j)] = if is_srai { F::ONE } else { F::ZERO }; - z[layout.is_lb(j)] = if is_lb { F::ONE } else { F::ZERO }; - z[layout.is_lbu(j)] = if is_lbu { F::ONE } else { F::ZERO }; - z[layout.is_lh(j)] = if is_lh { F::ONE } else { F::ZERO }; - z[layout.is_lhu(j)] = if is_lhu { F::ONE } else { F::ZERO }; - z[layout.is_lw(j)] = if is_lw { F::ONE } else { F::ZERO }; - z[layout.is_sb(j)] = if is_sb { F::ONE } else { F::ZERO }; - z[layout.is_sh(j)] = if is_sh { F::ONE } else { F::ZERO }; - z[layout.is_sw(j)] = if is_sw { F::ONE } else { F::ZERO }; - z[layout.is_amoswap_w(j)] = if is_amoswap_w { F::ONE } else { F::ZERO }; - z[layout.is_amoadd_w(j)] = if is_amoadd_w { F::ONE } else { F::ZERO }; - z[layout.is_amoxor_w(j)] = if is_amoxor_w { F::ONE } else { F::ZERO }; - z[layout.is_amoor_w(j)] = if is_amoor_w { F::ONE } else { F::ZERO }; - z[layout.is_amoand_w(j)] = if is_amoand_w { F::ONE } else { F::ZERO }; - z[layout.is_lui(j)] = if is_lui { F::ONE } else { F::ZERO }; - 
z[layout.is_auipc(j)] = if is_auipc { F::ONE } else { F::ZERO }; - z[layout.is_beq(j)] = if is_beq { F::ONE } else { F::ZERO }; - z[layout.is_bne(j)] = if is_bne { F::ONE } else { F::ZERO }; - z[layout.is_blt(j)] = if is_blt { F::ONE } else { F::ZERO }; - z[layout.is_bge(j)] = if is_bge { F::ONE } else { F::ZERO }; - z[layout.is_bltu(j)] = if is_bltu { F::ONE } else { F::ZERO }; - z[layout.is_bgeu(j)] = if is_bgeu { F::ONE } else { F::ZERO }; - z[layout.is_jal(j)] = if is_jal { F::ONE } else { F::ZERO }; - z[layout.is_jalr(j)] = if is_jalr { F::ONE } else { F::ZERO }; - z[layout.is_fence(j)] = if is_fence { F::ONE } else { F::ZERO }; - z[layout.is_halt(j)] = if is_halt { F::ONE } else { F::ZERO }; - set_ecall_helpers(&mut z, layout, j, step.regs_before[10], is_halt)?; - - // One-hot register selectors. - let rs1_idx = rs1 as usize; - let rs2_idx = rs2 as usize; - let rd_idx = rd as usize; - z[layout.rs1_sel(rs1_idx, j)] = F::ONE; - z[layout.rs2_sel(rs2_idx, j)] = F::ONE; - - // rd_sel: writes set rd_sel[rd] = 1, non-writes set rd_sel[0] = 1 (x0). - let writes_rd = is_add - || is_sub - || is_sll - || is_slt - || is_sltu - || is_xor - || is_srl - || is_sra - || is_or - || is_and - || is_mul - || is_mulh - || is_mulhu - || is_mulhsu - || is_div - || is_divu - || is_rem - || is_remu - || is_addi - || is_slti - || is_sltiu - || is_xori - || is_ori - || is_andi - || is_slli - || is_srli - || is_srai - || is_lb - || is_lbu - || is_lh - || is_lhu - || is_lw - || is_amoswap_w - || is_amoadd_w - || is_amoxor_w - || is_amoor_w - || is_amoand_w - || is_lui - || is_auipc - || is_jal - || is_jalr; - if writes_rd { - z[layout.rd_sel(rd_idx, j)] = F::ONE; - } else { - z[layout.rd_sel(0, j)] = F::ONE; - } - - // Selected operand values. 
- let rs1_u32 = u32::try_from(step.regs_before[rs1_idx]) - .map_err(|_| format!("RV32 B1: rs1 value does not fit in u32 at pc={:#x}", step.pc_before))?; - let rs2_u32 = u32::try_from(step.regs_before[rs2_idx]) - .map_err(|_| format!("RV32 B1: rs2 value does not fit in u32 at pc={:#x}", step.pc_before))?; - let rs1_u64 = rs1_u32 as u64; - let rs2_u64 = rs2_u32 as u64; - z[layout.rs1_val(j)] = F::from_u64(rs1_u64); - z[layout.rs2_val(j)] = F::from_u64(rs2_u64); - - // Helpers used by in-circuit RV32M constraints. - let rs1_sign = (rs1_u32 >> 31) & 1; - let rs2_sign = (rs2_u32 >> 31) & 1; - let rs1_abs = if rs1_sign == 0 { rs1_u64 } else { (1u64 << 32) - rs1_u64 }; - let rs2_abs = if rs2_sign == 0 { rs2_u64 } else { (1u64 << 32) - rs2_u64 }; - z[layout.rs1_abs(j)] = F::from_u64(rs1_abs); - z[layout.rs2_abs(j)] = F::from_u64(rs2_abs); - z[layout.rs1_rs2_sign_and(j)] = F::from_u64((rs1_sign & rs2_sign) as u64); - z[layout.rs1_sign_rs2_val(j)] = F::from_u64((rs1_sign as u64) * rs2_u64); - z[layout.rs2_sign_rs1_val(j)] = F::from_u64((rs2_sign as u64) * rs1_u64); - - // MUL product (always computed; constraints use it even when `is_mul=0`). - let mul_prod = (rs1_u64 as u128) * (rs2_u64 as u128); - let mul_lo = (mul_prod & 0xffff_ffff) as u64; - let mul_hi = ((mul_prod >> 32) & 0xffff_ffff) as u64; - z[layout.mul_lo(j)] = F::from_u64(mul_lo); - z[layout.mul_hi(j)] = F::from_u64(mul_hi); - - // Bit decomposition for rs1/rs2 (used for sign, abs, and zero checks). 
- for bit in 0..32 { - z[layout.rs1_bit(bit, j)] = if ((rs1_u32 >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; - z[layout.rs2_bit(bit, j)] = if ((rs2_u32 >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; - } - let mut prefix = F::ONE - z[layout.rs2_bit(0, j)]; - z[layout.rs2_zero_prefix(0, j)] = prefix; - for k in 1..31 { - prefix *= F::ONE - z[layout.rs2_bit(k, j)]; - z[layout.rs2_zero_prefix(k, j)] = prefix; - } - z[layout.rs2_is_zero(j)] = prefix * (F::ONE - z[layout.rs2_bit(31, j)]); - z[layout.rs2_nonzero(j)] = F::ONE - z[layout.rs2_is_zero(j)]; - - // DIV/REM helpers (unsigned + signed): quotient/remainder decomposition and remainder < divisor check. - let is_divu_or_remu = is_divu || is_remu; - let is_div_or_rem = is_div || is_rem; - let rs2_is_zero = rs2_u32 == 0; - z[layout.is_divu_or_remu(j)] = if is_divu_or_remu { F::ONE } else { F::ZERO }; - z[layout.is_div_or_rem(j)] = if is_div_or_rem { F::ONE } else { F::ZERO }; - - let do_rem_check = is_divu_or_remu && !rs2_is_zero; - let do_rem_check_signed = is_div_or_rem && !rs2_is_zero; - z[layout.div_rem_check(j)] = if do_rem_check { F::ONE } else { F::ZERO }; - z[layout.div_rem_check_signed(j)] = if do_rem_check_signed { F::ONE } else { F::ZERO }; - z[layout.divu_by_zero(j)] = if is_divu && rs2_is_zero { F::ONE } else { F::ZERO }; - z[layout.div_by_zero(j)] = if is_div && rs2_is_zero { F::ONE } else { F::ZERO }; - z[layout.div_nonzero(j)] = if is_div && !rs2_is_zero { F::ONE } else { F::ZERO }; - z[layout.rem_by_zero(j)] = if is_rem && rs2_is_zero { F::ONE } else { F::ZERO }; - z[layout.rem_nonzero(j)] = if is_rem && !rs2_is_zero { F::ONE } else { F::ZERO }; - - let (div_quot, div_rem, div_divisor) = if is_divu_or_remu { - if rs2_is_zero { - (u32::MAX as u64, rs1_u64, rs2_u64) - } else { - (rs1_u64 / rs2_u64, rs1_u64 % rs2_u64, rs2_u64) - } - } else if is_div_or_rem { - if rs2_is_zero { - (0u64, rs1_abs, rs2_abs) - } else { - (rs1_abs / rs2_abs, rs1_abs % rs2_abs, rs2_abs) - } - } else { - (0u64, 0u64, 0u64) - }; 
- z[layout.div_quot(j)] = F::from_u64(div_quot); - z[layout.div_rem(j)] = F::from_u64(div_rem); - z[layout.div_divisor(j)] = F::from_u64(div_divisor); - z[layout.div_prod(j)] = F::from_u64(((div_divisor as u128) * (div_quot as u128)) as u64); - - let div_sign = (rs1_sign ^ rs2_sign) as u64; - z[layout.div_sign(j)] = F::from_u64(div_sign); - let (div_quot_signed, div_quot_carry) = if div_sign == 0 { - (div_quot, 0u64) - } else if div_quot == 0 { - (0u64, 1u64) - } else { - ((1u64 << 32) - div_quot, 0u64) - }; - let (div_rem_signed, div_rem_carry) = if rs1_sign == 0 { - (div_rem, 0u64) - } else if div_rem == 0 { - (0u64, 1u64) - } else { - ((1u64 << 32) - div_rem, 0u64) - }; - z[layout.div_quot_signed(j)] = F::from_u64(div_quot_signed); - z[layout.div_rem_signed(j)] = F::from_u64(div_rem_signed); - z[layout.div_quot_carry(j)] = F::from_u64(div_quot_carry); - z[layout.div_rem_carry(j)] = F::from_u64(div_rem_carry); - - // Shared-bus bound values: lookup_key / alu_out / mem_rv / eff_addr. - let add_has_lookup = is_add - || is_addi - || is_lb - || is_lbu - || is_lh - || is_lhu - || is_lw - || is_sb - || is_sh - || is_sw - || is_amoadd_w - || is_auipc - || is_jalr; - z[layout.add_has_lookup(j)] = if add_has_lookup { F::ONE } else { F::ZERO }; - let and_has_lookup = is_and || is_andi || is_amoand_w; - z[layout.and_has_lookup(j)] = if and_has_lookup { F::ONE } else { F::ZERO }; - let xor_has_lookup = is_xor || is_xori || is_amoxor_w; - z[layout.xor_has_lookup(j)] = if xor_has_lookup { F::ONE } else { F::ZERO }; - let or_has_lookup = is_or || is_ori || is_amoor_w; - z[layout.or_has_lookup(j)] = if or_has_lookup { F::ONE } else { F::ZERO }; - let sll_has_lookup = is_sll || is_slli; - z[layout.sll_has_lookup(j)] = if sll_has_lookup { F::ONE } else { F::ZERO }; - let srl_has_lookup = is_srl || is_srli; - z[layout.srl_has_lookup(j)] = if srl_has_lookup { F::ONE } else { F::ZERO }; - let sra_has_lookup = is_sra || is_srai; - z[layout.sra_has_lookup(j)] = if sra_has_lookup { 
F::ONE } else { F::ZERO }; - let slt_has_lookup = is_slt || is_slti || is_blt || is_bge; - z[layout.slt_has_lookup(j)] = if slt_has_lookup { F::ONE } else { F::ZERO }; - let sltu_has_lookup = is_sltu || is_sltiu || is_bltu || is_bgeu || do_rem_check || do_rem_check_signed; - z[layout.sltu_has_lookup(j)] = if sltu_has_lookup { F::ONE } else { F::ZERO }; - - let is_amo = is_amoswap_w || is_amoadd_w || is_amoxor_w || is_amoor_w || is_amoand_w; - let ram_has_read = is_lb || is_lbu || is_lh || is_lhu || is_lw || is_sb || is_sh || is_amo; - let ram_has_write = is_sb || is_sh || is_sw || is_amo; - z[layout.ram_has_read(j)] = if ram_has_read { F::ONE } else { F::ZERO }; - z[layout.ram_has_write(j)] = if ram_has_write { F::ONE } else { F::ZERO }; - - // Default zeros. - z[layout.lookup_key(j)] = F::ZERO; - z[layout.alu_out(j)] = F::ZERO; - z[layout.add_a0b0(j)] = F::ZERO; - z[layout.mem_rv(j)] = F::ZERO; - z[layout.eff_addr(j)] = F::ZERO; - z[layout.ram_wv(j)] = F::ZERO; - z[layout.rd_write_val(j)] = F::ZERO; - z[layout.br_taken(j)] = F::ZERO; - z[layout.br_not_taken(j)] = F::ZERO; - - // RAM events: validate shape and fill the RAM twist lane + CPU mirrors. 
- let is_load = is_lb || is_lbu || is_lh || is_lhu || is_lw; - let is_store_rmw = is_sb || is_sh; - if is_load { - if ram_read.is_none() || ram_write.is_some() { - return Err(format!( - "RV32 B1: load expects one RAM read and no RAM write at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } else if is_store_rmw { - if ram_read.is_none() || ram_write.is_none() { - return Err(format!( - "RV32 B1: byte/half store expects one RAM read and one RAM write at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } else if is_sw { - if ram_read.is_some() || ram_write.is_none() { - return Err(format!( - "RV32 B1: SW expects one RAM write and no RAM read at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } else if is_amo { - if ram_read.is_none() || ram_write.is_none() { - return Err(format!( - "RV32 B1: AMO expects one RAM read and one RAM write at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - } else if ram_read.is_some() || ram_write.is_some() { - return Err(format!( - "RV32 B1: unexpected RAM event(s) at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - - if let (Some((read_addr, _)), Some((write_addr, _))) = (ram_read, ram_write) { - if read_addr != write_addr { - return Err(format!( - "RV32 B1: RAM read/write addr mismatch at pc={:#x} (chunk j={j}): read_addr={:#x}, write_addr={:#x}", - step.pc_before, read_addr, write_addr - )); - } - } - - if let Some((addr, value)) = ram_read { - z[layout.eff_addr(j)] = F::from_u64(addr); - z[layout.mem_rv(j)] = F::from_u64(value); - } - if let Some((addr, value)) = ram_write { - z[layout.eff_addr(j)] = F::from_u64(addr); - z[layout.ram_wv(j)] = F::from_u64(value); - } - - if fill_bus { - let ram_bus_has_read = ram_read.is_some(); - let ram_bus_has_write = ram_write.is_some(); - set_bus_cell( - &mut z, - layout, - ram_lane.has_read, - j, - if ram_bus_has_read { F::ONE } else { F::ZERO }, - ); - set_bus_cell( - &mut z, - layout, - ram_lane.has_write, - j, - if ram_bus_has_write { F::ONE } else { F::ZERO }, - ); 
- if let Some((addr, value)) = ram_read { - write_bus_u64_bits( - &mut z, - layout, - ram_lane.ra_bits.start, - ram_lane.ra_bits.end - ram_lane.ra_bits.start, - j, - addr, - ); - set_bus_cell(&mut z, layout, ram_lane.rv, j, F::from_u64(value)); - } - if let Some((addr, value)) = ram_write { - write_bus_u64_bits( - &mut z, - layout, - ram_lane.wa_bits.start, - ram_lane.wa_bits.end - ram_lane.wa_bits.start, - j, - addr, - ); - set_bus_cell(&mut z, layout, ram_lane.wv, j, F::from_u64(value)); - } - set_bus_cell(&mut z, layout, ram_lane.inc, j, F::ZERO); - } - - // Shout events: expect at most one lookup and bind it to a single lane. - let shout_ev = match step.shout_events.as_slice() { - [] => None, - [one] => Some(one), - _ => { - let ids: Vec = step.shout_events.iter().map(|ev| ev.shout_id.0).collect(); - return Err(format!( - "RV32 B1: multiple shout events in one step (pc={:#x}, chunk j={j}): shout_ids={ids:?}; this circuit has 1 Shout port (no lanes) so you must either provision multiple Shout lanes in the shared CPU bus or split into micro-steps", - step.pc_before - )); - } - }; - - let mut expected_table_id: Option = None; - let mut expect_table = |flag: bool, table_id: u32, name: &str| -> Result<(), String> { - if !flag { - return Ok(()); - } - if expected_table_id.replace(table_id).is_some() { - return Err(format!( - "RV32 B1: multiple Shout lookups expected at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - if layout.shout_idx(table_id).is_err() { - return Err(format!( - "RV32 B1: missing Shout table {name} (id={table_id}) at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - Ok(()) - }; - - expect_table(add_has_lookup, ADD_TABLE_ID, "ADD")?; - expect_table(and_has_lookup, AND_TABLE_ID, "AND")?; - expect_table(xor_has_lookup, XOR_TABLE_ID, "XOR")?; - expect_table(or_has_lookup, OR_TABLE_ID, "OR")?; - expect_table(sll_has_lookup, SLL_TABLE_ID, "SLL")?; - expect_table(srl_has_lookup, SRL_TABLE_ID, "SRL")?; - expect_table(sra_has_lookup, SRA_TABLE_ID, 
"SRA")?; - expect_table(slt_has_lookup, SLT_TABLE_ID, "SLT")?; - expect_table(sltu_has_lookup, SLTU_TABLE_ID, "SLTU")?; - expect_table(is_sub, SUB_TABLE_ID, "SUB")?; - expect_table(is_beq, EQ_TABLE_ID, "EQ")?; - expect_table(is_bne, NEQ_TABLE_ID, "NEQ")?; - - match (expected_table_id, shout_ev) { - (None, None) => {} - (None, Some(ev)) => { - return Err(format!( - "RV32 B1: unexpected shout event id={} at pc={:#x} (chunk j={j})", - ev.shout_id.0, step.pc_before - )); - } - (Some(expected), None) => { - return Err(format!( - "RV32 B1: missing shout event for table_id={expected} at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - (Some(expected), Some(ev)) => { - let got = ev.shout_id.0; - if got != expected { - return Err(format!( - "RV32 B1: shout table id mismatch at pc={:#x} (chunk j={j}): expected={expected}, got={got}", - step.pc_before - )); - } - let table_idx = layout - .shout_idx(expected) - .map_err(|e| format!("RV32 B1: {e} at pc={:#x} (chunk j={j})", step.pc_before))?; - let lane = &layout.bus.shout_cols[table_idx].lanes[0]; - if fill_bus { - set_bus_cell(&mut z, layout, lane.has_lookup, j, F::ONE); - write_bus_u64_bits( - &mut z, - layout, - lane.addr_bits.start, - lane.addr_bits.end - lane.addr_bits.start, - j, - ev.key, - ); - set_bus_cell(&mut z, layout, lane.val, j, F::from_u64(ev.value)); - } - z[layout.alu_out(j)] = F::from_u64(ev.value); - } - } - - if fill_bus { - let add_a0 = z[layout.bus.bus_cell(add_lane.addr_bits.start + 0, j)]; - let add_b0 = z[layout.bus.bus_cell(add_lane.addr_bits.start + 1, j)]; - z[layout.add_a0b0(j)] = add_a0 * add_b0; - } else if let Some(ev) = shout_ev { - if ev.shout_id.0 == ADD_TABLE_ID { - let a0 = if (ev.key & 1) == 1 { F::ONE } else { F::ZERO }; - let b0 = if ((ev.key >> 1) & 1) == 1 { F::ONE } else { F::ZERO }; - z[layout.add_a0b0(j)] = a0 * b0; - } - } - - let rs1_i32 = rs1_u32 as i32; - let rs2_i32 = rs2_u32 as i32; - let mulh_u32 = if is_mulh { - ((rs1_i32 as i64 * rs2_i32 as i64) >> 32) as i32 as u32 - 
} else { - 0u32 - }; - let mulhsu_u32 = if is_mulhsu { - ((rs1_i32 as i64 * rs2_u32 as i64) >> 32) as i32 as u32 - } else { - 0u32 - }; - - // Writeback value. - if is_add - || is_sub - || is_sll - || is_slt - || is_sltu - || is_xor - || is_srl - || is_sra - || is_or - || is_and - || is_addi - || is_slti - || is_sltiu - || is_xori - || is_ori - || is_andi - || is_slli - || is_srli - || is_srai - || is_auipc - { - z[layout.rd_write_val(j)] = z[layout.alu_out(j)]; - } - if is_mul { - z[layout.rd_write_val(j)] = F::from_u64(mul_lo); - } - if is_mulhu { - z[layout.rd_write_val(j)] = F::from_u64(mul_hi); - } - if is_mulh { - z[layout.rd_write_val(j)] = F::from_u64(mulh_u32 as u64); - } - if is_mulhsu { - z[layout.rd_write_val(j)] = F::from_u64(mulhsu_u32 as u64); - } - if is_divu { - z[layout.rd_write_val(j)] = z[layout.div_quot(j)]; - } - if is_remu { - z[layout.rd_write_val(j)] = z[layout.div_rem(j)]; - } - if is_div { - if rs2_is_zero { - z[layout.rd_write_val(j)] = F::from_u64(u32::MAX as u64); - } else { - z[layout.rd_write_val(j)] = z[layout.div_quot_signed(j)]; - } - } - if is_rem { - z[layout.rd_write_val(j)] = z[layout.div_rem_signed(j)]; - } - if is_lw || is_amoswap_w || is_amoadd_w || is_amoxor_w || is_amoor_w || is_amoand_w { - z[layout.rd_write_val(j)] = z[layout.mem_rv(j)]; - } - if is_lb || is_lbu || is_lh || is_lhu { - let mem_rv_u64 = z[layout.mem_rv(j)].as_canonical_u64(); - let mem_rv_u32 = u32::try_from(mem_rv_u64).map_err(|_| { - format!( - "RV32 B1: mem_rv does not fit in u32 at pc={:#x}: {mem_rv_u64}", - step.pc_before - ) - })?; - if is_lb { - let byte = (mem_rv_u32 & 0xff) as u8; - z[layout.rd_write_val(j)] = F::from_u64((byte as i8 as i32 as u32) as u64); - } - if is_lbu { - z[layout.rd_write_val(j)] = F::from_u64((mem_rv_u32 & 0xff) as u64); - } - if is_lh { - let half = (mem_rv_u32 & 0xffff) as u16; - z[layout.rd_write_val(j)] = F::from_u64((half as i16 as i32 as u32) as u64); - } - if is_lhu { - z[layout.rd_write_val(j)] = 
F::from_u64((mem_rv_u32 & 0xffff) as u64); - } - } - if is_lui { - z[layout.rd_write_val(j)] = z[layout.imm_u(j)]; - } - if is_jal || is_jalr { - z[layout.rd_write_val(j)] = F::from_u64(step.pc_before.wrapping_add(4)); - } - if is_beq || is_bne || is_blt || is_bltu { - z[layout.br_taken(j)] = z[layout.alu_out(j)]; - z[layout.br_not_taken(j)] = F::ONE - z[layout.br_taken(j)]; - } - if is_bge || is_bgeu { - z[layout.br_taken(j)] = F::ONE - z[layout.alu_out(j)]; - z[layout.br_not_taken(j)] = F::ONE - z[layout.br_taken(j)]; - } - - let mul_carry = if is_mulh { - let rhs = - (mul_hi as i128) - (rs1_sign as i128) * (rs2_u64 as i128) - (rs2_sign as i128) * (rs1_u64 as i128) - + (rs1_sign as i128) * (rs2_sign as i128) * (1i128 << 32) - + (1i128 << 32); - let diff = rhs - (mulh_u32 as i128); - if diff < 0 || diff % (1i128 << 32) != 0 { - return Err(format!( - "RV32 B1: MULH carry mismatch at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - (diff >> 32) as u64 - } else if is_mulhsu { - let rhs = (mul_hi as i128) - (rs1_sign as i128) * (rs2_u64 as i128) + (1i128 << 32); - let diff = rhs - (mulhsu_u32 as i128); - if diff < 0 || diff % (1i128 << 32) != 0 { - return Err(format!( - "RV32 B1: MULHSU carry mismatch at pc={:#x} (chunk j={j})", - step.pc_before - )); - } - (diff >> 32) as u64 - } else { - 0u64 - }; - z[layout.mul_carry(j)] = F::from_u64(mul_carry); - - let rd_write_u64 = z[layout.rd_write_val(j)].as_canonical_u64(); - let rd_write_u32 = u32::try_from(rd_write_u64) - .map_err(|_| format!("RV32 B1: rd_write_val does not fit in u32: {rd_write_u64}"))?; - let mem_rv_u64 = z[layout.mem_rv(j)].as_canonical_u64(); - let mem_rv_u32 = - u32::try_from(mem_rv_u64).map_err(|_| format!("RV32 B1: mem_rv does not fit in u32: {mem_rv_u64}"))?; - let div_quot_u32 = u32::try_from(div_quot).map_err(|_| "RV32 B1: div_quot overflow".to_string())?; - let div_rem_u32 = u32::try_from(div_rem).map_err(|_| "RV32 B1: div_rem overflow".to_string())?; - - for bit in 0..32 { - 
z[layout.rd_write_bit(bit, j)] = if ((rd_write_u32 >> bit) & 1) == 1 { - F::ONE - } else { - F::ZERO - }; - z[layout.mem_rv_bit(bit, j)] = if ((mem_rv_u32 >> bit) & 1) == 1 { - F::ONE - } else { - F::ZERO - }; - z[layout.mul_lo_bit(bit, j)] = if ((mul_lo as u32 >> bit) & 1) == 1 { - F::ONE - } else { - F::ZERO - }; - z[layout.mul_hi_bit(bit, j)] = if ((mul_hi as u32 >> bit) & 1) == 1 { - F::ONE - } else { - F::ZERO - }; - z[layout.div_quot_bit(bit, j)] = if ((div_quot_u32 >> bit) & 1) == 1 { - F::ONE - } else { - F::ZERO - }; - z[layout.div_rem_bit(bit, j)] = if ((div_rem_u32 >> bit) & 1) == 1 { - F::ONE - } else { - F::ZERO - }; - } - for bit in 0..2 { - z[layout.mul_carry_bit(bit, j)] = if ((mul_carry >> bit) & 1) == 1 { F::ONE } else { F::ZERO }; - } - - let mut prefix = if (mul_hi as u32 & 1) == 1 { F::ONE } else { F::ZERO }; - z[layout.mul_hi_prefix(0, j)] = prefix; - for k in 1..31 { - let bit = ((mul_hi as u32 >> k) & 1) == 1; - prefix *= if bit { F::ONE } else { F::ZERO }; - z[layout.mul_hi_prefix(k, j)] = prefix; - } - } - - z[layout.pc_final] = F::from_u64(carried_pc); - for r in 0..32 { - z[layout.regs_final_start + r] = F::from_u64(carried_regs[r]); - } - z[layout.regs_final_start] = F::ZERO; - - // Chunk-level halting state used for cross-chunk padding semantics. 
- z[layout.halted_in] = F::ONE - z[layout.is_active(0)]; - let j_last = layout.chunk_size - 1; - z[layout.halted_out] = F::ONE - z[layout.is_active(j_last)] + z[layout.halt_effective(j_last)]; - - Ok(z) -} diff --git a/crates/neo-memory/src/riscv/exec_table.rs b/crates/neo-memory/src/riscv/exec_table.rs new file mode 100644 index 00000000..61d71b82 --- /dev/null +++ b/crates/neo-memory/src/riscv/exec_table.rs @@ -0,0 +1,1067 @@ +use neo_vm_trace::{ShoutEvent, StepTrace, TwistEvent, TwistOpKind, VmTrace}; + +use crate::riscv::lookups::{ + compute_op, decode_instruction, interleave_bits, uninterleave_bits, RiscvInstruction, RiscvOpcode, + RiscvShoutTables, PROG_ID, RAM_ID, REG_ID, +}; +use std::collections::HashMap; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Rv32InstrFields { + pub opcode: u32, + pub rd: u8, + pub funct3: u32, + pub rs1: u8, + pub rs2: u8, + pub funct7: u32, +} + +impl Rv32InstrFields { + pub fn from_word(instr_word: u32) -> Self { + Self { + opcode: instr_word & 0x7f, + rd: ((instr_word >> 7) & 0x1f) as u8, + funct3: (instr_word >> 12) & 0x7, + rs1: ((instr_word >> 15) & 0x1f) as u8, + rs2: ((instr_word >> 20) & 0x1f) as u8, + funct7: (instr_word >> 25) & 0x7f, + } + } +} + +#[derive(Clone, Debug)] +pub struct Rv32RegLaneIo { + pub addr: u64, + pub value: u64, +} + +#[derive(Clone, Debug)] +pub struct Rv32ExecRow { + /// True for real trace rows; false for padded/inactive rows. + pub active: bool, + + pub cycle: u64, + pub pc_before: u64, + pub pc_after: u64, + pub instr_word: u32, + pub fields: Rv32InstrFields, + pub halted: bool, + + /// Decoded instruction (for semantic context; derived from `instr_word`). + pub decoded: Option, + + /// PROG ROM fetch (`PROG_ID`) for this step. + pub prog_read: Option>, + + /// REG lane 0 read (`REG_ID`, lane=0): rs1_field → rs1_val. + pub reg_read_lane0: Option, + + /// REG lane 1 read (`REG_ID`, lane=1): rs2_field → rs2_val. 
+ pub reg_read_lane1: Option, + + /// Optional REG lane 0 write (`REG_ID`, lane=0): rd_field → rd_write_val. + pub reg_write_lane0: Option, + + /// RAM twist events (`RAM_ID`) for this step. + pub ram_events: Vec>, + + /// Shout events for this step. + pub shout_events: Vec>, +} + +#[derive(Clone, Debug)] +pub struct Rv32ExecColumns { + pub active: Vec, + pub cycle: Vec, + pub pc_before: Vec, + pub pc_after: Vec, + pub instr_word: Vec, + pub opcode: Vec, + pub rd: Vec, + pub funct3: Vec, + pub rs1: Vec, + pub rs2: Vec, + pub funct7: Vec, + pub halted: Vec, + pub prog_addr: Vec, + pub prog_value: Vec, + pub rs1_addr: Vec, + pub rs1_val: Vec, + pub rs2_addr: Vec, + pub rs2_val: Vec, + pub rd_has_write: Vec, + pub rd_addr: Vec, + pub rd_val: Vec, +} + +impl Rv32ExecColumns { + pub fn len(&self) -> usize { + self.cycle.len() + } +} + +#[derive(Clone, Debug)] +pub struct Rv32ExecTable { + pub rows: Vec, +} + +impl Rv32ExecTable { + pub fn from_trace(trace: &VmTrace) -> Result { + let mut rows = Vec::with_capacity(trace.steps.len()); + for step in &trace.steps { + rows.push(Rv32ExecRow::from_step(step)?); + } + Ok(Self { rows }) + } + + pub fn from_trace_padded(trace: &VmTrace, padded_len: usize) -> Result { + if padded_len < trace.steps.len() { + return Err(format!( + "padded_len must be >= trace length (padded_len={} trace_len={})", + padded_len, + trace.steps.len() + )); + } + + let mut rows = Vec::with_capacity(padded_len); + for step in &trace.steps { + rows.push(Rv32ExecRow::from_step(step)?); + } + if rows.is_empty() { + if padded_len == 0 { + return Ok(Self { rows }); + } + return Err("cannot pad empty trace without an initial pc".into()); + } + + let last = rows.last().expect("rows non-empty"); + let mut cycle = last.cycle; + let pad_pc = last.pc_after; + let pad_halted = last.halted; + + while rows.len() < padded_len { + cycle = cycle + .checked_add(1) + .ok_or_else(|| "cycle overflow while padding".to_string())?; + rows.push(Rv32ExecRow::inactive(cycle, 
pad_pc, pad_halted)); + } + + Ok(Self { rows }) + } + + pub fn from_trace_padded_pow2(trace: &VmTrace, min_len: usize) -> Result { + let steps = trace.steps.len(); + let target = steps.max(min_len).next_power_of_two(); + Self::from_trace_padded(trace, target) + } + + pub fn validate_pc_chain(&self) -> Result<(), String> { + for w in self.rows.windows(2) { + let a = &w[0]; + let b = &w[1]; + if a.pc_after != b.pc_before { + return Err(format!( + "pc chain mismatch: cycle {} pc_after={:#x} != cycle {} pc_before={:#x}", + a.cycle, a.pc_after, b.cycle, b.pc_before + )); + } + } + Ok(()) + } + + /// Validate that cycles are consecutive (`cycle[t+1] = cycle[t] + 1`). + pub fn validate_cycle_chain(&self) -> Result<(), String> { + for w in self.rows.windows(2) { + let a = &w[0]; + let b = &w[1]; + if b.cycle != a.cycle + 1 { + return Err(format!( + "cycle chain mismatch: cycle {} then {} (expected {})", + a.cycle, + b.cycle, + a.cycle + 1 + )); + } + } + Ok(()) + } + + /// Validate that inactive rows contain no events and no decoded instruction. + pub fn validate_inactive_rows_are_empty(&self) -> Result<(), String> { + for r in &self.rows { + if r.active { + continue; + } + if r.decoded.is_some() + || r.prog_read.is_some() + || r.reg_read_lane0.is_some() + || r.reg_read_lane1.is_some() + || r.reg_write_lane0.is_some() + || !r.ram_events.is_empty() + || !r.shout_events.is_empty() + { + return Err(format!("inactive row has events/decoded at cycle {}", r.cycle)); + } + } + Ok(()) + } + + /// Validate that once `halted` becomes true, it stays true and the PC stops changing. + pub fn validate_halted_tail(&self) -> Result<(), String> { + let mut saw_halt = false; + let mut halt_pc: Option = None; + for r in &self.rows { + if !saw_halt { + if r.halted { + saw_halt = true; + // In our trace semantics, the HALT row itself can advance the PC (default +4), + // but after that the machine is halted and PC should stop changing. 
+ halt_pc = Some(r.pc_after); + } + continue; + } + + if !r.halted { + return Err(format!( + "halted tail violated: halted dropped to false at cycle {} (pc_before={:#x})", + r.cycle, r.pc_before + )); + } + + let pc0 = halt_pc.expect("halt_pc set"); + if r.pc_before != pc0 || r.pc_after != pc0 { + return Err(format!( + "halted tail violated: pc changed after halt at cycle {} (pc_before={:#x} pc_after={:#x}, expected {:#x})", + r.cycle, r.pc_before, r.pc_after, pc0 + )); + } + } + Ok(()) + } + + /// Validate strict JALR next-PC policy used by trace-wiring control claims. + /// + /// Current trace-wiring control stage enforces `pc_after = rs1_val + imm_i` for JALR rows + /// (no committed drop-bit helper columns). Under this policy, traces requiring ISA-level + /// JALR masking are out of scope and must be rejected during trace construction. + pub fn validate_jalr_strict_alignment_policy(&self) -> Result<(), String> { + for r in &self.rows { + if !r.active { + continue; + } + let Some(crate::riscv::lookups::RiscvInstruction::Jalr { imm, .. }) = r.decoded.as_ref() else { + continue; + }; + let rs1_val = r.reg_read_lane0.as_ref().map(|io| io.value).unwrap_or(0); + let imm_u32 = *imm as u32 as u64; + let expected_pc_after = rs1_val.wrapping_add(imm_u32); + if r.pc_after != expected_pc_after { + return Err(format!( + "strict JALR policy violated at cycle {}: pc_after={:#x}, expected rs1+imm={:#x} (rs1={:#x}, imm={:#x})", + r.cycle, r.pc_after, expected_pc_after, rs1_val, imm_u32 + )); + } + } + Ok(()) + } + + /// Validate REG lane semantics by replaying the register file from an initial state. + /// + /// - `init_regs` maps `reg_idx (0..31)` → value (u32 stored in u64). + /// - Unspecified registers default to 0. + /// - Reads happen before the optional lane0 write in each cycle. 
+ pub fn validate_regfile_semantics(&self, init_regs: &HashMap) -> Result<(), String> { + let mut regs = [0u64; 32]; + for (&addr, &value) in init_regs { + if addr >= 32 { + return Err(format!("reg init addr out of range: addr={addr}")); + } + if addr == 0 && value != 0 { + return Err("reg init must keep x0 == 0".into()); + } + regs[addr as usize] = value; + } + + for r in &self.rows { + if !r.active { + continue; + } + + let Some(rs1) = &r.reg_read_lane0 else { + return Err(format!("missing REG lane0 read at cycle {}", r.cycle)); + }; + let Some(rs2) = &r.reg_read_lane1 else { + return Err(format!("missing REG lane1 read at cycle {}", r.cycle)); + }; + if rs1.addr >= 32 || rs2.addr >= 32 { + return Err(format!( + "REG read addr out of range at cycle {}: lane0={} lane1={}", + r.cycle, rs1.addr, rs2.addr + )); + } + + let exp_rs1 = regs[rs1.addr as usize]; + let exp_rs2 = regs[rs2.addr as usize]; + if rs1.value != exp_rs1 { + return Err(format!( + "REG lane0 read value mismatch at cycle {} pc={:#x}: addr={} got={:#x} expected={:#x}", + r.cycle, r.pc_before, rs1.addr, rs1.value, exp_rs1 + )); + } + if rs2.value != exp_rs2 { + return Err(format!( + "REG lane1 read value mismatch at cycle {} pc={:#x}: addr={} got={:#x} expected={:#x}", + r.cycle, r.pc_before, rs2.addr, rs2.value, exp_rs2 + )); + } + + if let Some(w) = &r.reg_write_lane0 { + if w.addr >= 32 { + return Err(format!( + "REG write addr out of range at cycle {}: addr={}", + r.cycle, w.addr + )); + } + if w.addr == 0 { + return Err(format!( + "unexpected x0 write at cycle {} pc={:#x}", + r.cycle, r.pc_before + )); + } + regs[w.addr as usize] = w.value; + } + + // x0 is always 0. + regs[0] = 0; + } + + Ok(()) + } + + /// Validate RAM twist semantics by replaying the RAM state from an initial state. + /// + /// - `init_ram` maps `byte_addr` → word value (u32 stored in u64) under the RV32 trace convention. + /// - Unspecified addresses default to 0. 
+ /// - Multiple RAM events in a cycle are applied in trace order (e.g. SB/SH read-modify-write). + pub fn validate_ram_semantics(&self, init_ram: &HashMap) -> Result<(), String> { + let mut mem: HashMap = HashMap::new(); + for (&addr, &value) in init_ram { + if value == 0 { + continue; + } + mem.insert(addr, value); + } + + for r in &self.rows { + if !r.active { + continue; + } + + for e in &r.ram_events { + match e.kind { + TwistOpKind::Read => { + let exp = mem.get(&e.addr).copied().unwrap_or(0); + if e.value != exp { + return Err(format!( + "RAM read value mismatch at cycle {} pc={:#x}: addr={:#x} got={:#x} expected={:#x}", + r.cycle, r.pc_before, e.addr, e.value, exp + )); + } + } + TwistOpKind::Write => { + if e.value == 0 { + mem.remove(&e.addr); + } else { + mem.insert(e.addr, e.value); + } + } + } + } + } + + Ok(()) + } + + pub fn to_columns(&self) -> Rv32ExecColumns { + let n = self.rows.len(); + + let mut out = Rv32ExecColumns { + active: Vec::with_capacity(n), + cycle: Vec::with_capacity(n), + pc_before: Vec::with_capacity(n), + pc_after: Vec::with_capacity(n), + instr_word: Vec::with_capacity(n), + opcode: Vec::with_capacity(n), + rd: Vec::with_capacity(n), + funct3: Vec::with_capacity(n), + rs1: Vec::with_capacity(n), + rs2: Vec::with_capacity(n), + funct7: Vec::with_capacity(n), + halted: Vec::with_capacity(n), + prog_addr: Vec::with_capacity(n), + prog_value: Vec::with_capacity(n), + rs1_addr: Vec::with_capacity(n), + rs1_val: Vec::with_capacity(n), + rs2_addr: Vec::with_capacity(n), + rs2_val: Vec::with_capacity(n), + rd_has_write: Vec::with_capacity(n), + rd_addr: Vec::with_capacity(n), + rd_val: Vec::with_capacity(n), + }; + + for r in &self.rows { + out.active.push(r.active); + out.cycle.push(r.cycle); + out.pc_before.push(r.pc_before); + out.pc_after.push(r.pc_after); + out.instr_word.push(r.instr_word); + out.opcode.push(r.fields.opcode); + out.rd.push(r.fields.rd); + out.funct3.push(r.fields.funct3); + out.rs1.push(r.fields.rs1); + 
out.rs2.push(r.fields.rs2); + out.funct7.push(r.fields.funct7); + out.halted.push(r.halted); + + match &r.prog_read { + Some(e) => { + out.prog_addr.push(e.addr); + out.prog_value.push(e.value); + } + None => { + out.prog_addr.push(0); + out.prog_value.push(0); + } + } + + match &r.reg_read_lane0 { + Some(io) => { + out.rs1_addr.push(io.addr); + out.rs1_val.push(io.value); + } + None => { + out.rs1_addr.push(0); + out.rs1_val.push(0); + } + } + + match &r.reg_read_lane1 { + Some(io) => { + out.rs2_addr.push(io.addr); + out.rs2_val.push(io.value); + } + None => { + out.rs2_addr.push(0); + out.rs2_val.push(0); + } + } + + match &r.reg_write_lane0 { + Some(io) => { + out.rd_has_write.push(true); + out.rd_addr.push(io.addr); + out.rd_val.push(io.value); + } + None => { + out.rd_has_write.push(false); + out.rd_addr.push(0); + out.rd_val.push(0); + } + } + } + + out + } +} + +impl Rv32ExecRow { + pub fn from_step(step: &StepTrace) -> Result { + let instr_word = step.opcode; + let fields = Rv32InstrFields::from_word(instr_word); + let decoded = decode_instruction(instr_word).map_err(|e| { + format!( + "decode_instruction failed at cycle {} pc={:#x} word={:#x}: {e}", + step.cycle, step.pc_before, instr_word + ) + })?; + + // PROG fetch + let prog_read = { + let mut reads = step + .twist_events + .iter() + .filter(|e| e.twist_id == PROG_ID && matches!(e.kind, TwistOpKind::Read)) + .cloned(); + let first = reads.next().ok_or_else(|| { + format!( + "missing PROG_ID read event at cycle {} pc={:#x}", + step.cycle, step.pc_before + ) + })?; + if reads.next().is_some() { + return Err(format!( + "expected exactly 1 PROG_ID read event at cycle {} pc={:#x}", + step.cycle, step.pc_before + )); + } + first + }; + if prog_read.addr != step.pc_before { + return Err(format!( + "PROG_ID read addr mismatch at cycle {}: got={:#x} expected pc_before={:#x}", + step.cycle, prog_read.addr, step.pc_before + )); + } + if prog_read.value != instr_word as u64 { + return Err(format!( + "PROG_ID read 
value mismatch at cycle {} pc={:#x}: got={:#x} expected instr_word={:#x}", + step.cycle, step.pc_before, prog_read.value, instr_word + )); + } + if prog_read.lane.is_some() { + return Err(format!( + "unexpected PROG_ID lane hint at cycle {} pc={:#x}: lane={:?}", + step.cycle, step.pc_before, prog_read.lane + )); + } + + // REG reads (lane 0 and lane 1) + let mut reg_read_lane0: Option = None; + let mut reg_read_lane1: Option = None; + let mut reg_write_lane0: Option = None; + for e in step.twist_events.iter().filter(|e| e.twist_id == REG_ID) { + match e.kind { + TwistOpKind::Read => match e.lane { + Some(0) => { + if reg_read_lane0.is_some() { + return Err(format!( + "duplicate REG_ID lane 0 read at cycle {} pc={:#x}", + step.cycle, step.pc_before + )); + } + reg_read_lane0 = Some(Rv32RegLaneIo { + addr: e.addr, + value: e.value, + }); + } + Some(1) => { + if reg_read_lane1.is_some() { + return Err(format!( + "duplicate REG_ID lane 1 read at cycle {} pc={:#x}", + step.cycle, step.pc_before + )); + } + reg_read_lane1 = Some(Rv32RegLaneIo { + addr: e.addr, + value: e.value, + }); + } + other => { + return Err(format!( + "unexpected REG_ID read lane {:?} at cycle {} pc={:#x}", + other, step.cycle, step.pc_before + )); + } + }, + TwistOpKind::Write => match e.lane { + Some(0) => { + if reg_write_lane0.is_some() { + return Err(format!( + "duplicate REG_ID lane 0 write at cycle {} pc={:#x}", + step.cycle, step.pc_before + )); + } + reg_write_lane0 = Some(Rv32RegLaneIo { + addr: e.addr, + value: e.value, + }); + } + other => { + return Err(format!( + "unexpected REG_ID write lane {:?} at cycle {} pc={:#x}", + other, step.cycle, step.pc_before + )); + } + }, + } + } + let reg_read_lane0 = reg_read_lane0.ok_or_else(|| { + format!( + "missing REG_ID lane 0 read at cycle {} pc={:#x}", + step.cycle, step.pc_before + ) + })?; + let reg_read_lane1 = reg_read_lane1.ok_or_else(|| { + format!( + "missing REG_ID lane 1 read at cycle {} pc={:#x}", + step.cycle, step.pc_before + ) + 
})?; + if let Some(w) = ®_write_lane0 { + if fields.rd == 0 { + return Err(format!( + "unexpected REG_ID lane 0 write to x0 at cycle {} pc={:#x}", + step.cycle, step.pc_before + )); + } + if w.addr != fields.rd as u64 { + return Err(format!( + "REG lane0 write addr mismatch at cycle {} pc={:#x}: got={} expected rd_field={}", + step.cycle, step.pc_before, w.addr, fields.rd + )); + } + } + + // Light sanity check: make sure the trace's lane policy matches RV32 trace conventions. + // + // - lane0 reads rs1_field always + // - lane1 reads rs2_field + let rs2_expected = fields.rs2 as u64; + if reg_read_lane0.addr != fields.rs1 as u64 { + return Err(format!( + "REG lane0 read addr mismatch at cycle {} pc={:#x}: got={} expected rs1_field={}", + step.cycle, step.pc_before, reg_read_lane0.addr, fields.rs1 + )); + } + if reg_read_lane1.addr != rs2_expected { + return Err(format!( + "REG lane1 read addr mismatch at cycle {} pc={:#x}: got={} expected={}", + step.cycle, step.pc_before, reg_read_lane1.addr, rs2_expected + )); + } + + // RAM events + let ram_events: Vec> = step + .twist_events + .iter() + .filter(|e| e.twist_id == RAM_ID) + .cloned() + .collect(); + + // Shout events + let mut shout_events = step.shout_events.clone(); + if shout_events.is_empty() { + // Backfill RV32M shout events for trace/event-table consumers. + // + // Some trace builders currently omit explicit Shout events for RV32M rows even when + // the operation is semantically Shout-backed. Reconstruct the canonical event from the + // decoded op and the architectural operands. + if let RiscvInstruction::RAlu { op, .. 
} = &decoded { + let is_rv32m = matches!( + op, + RiscvOpcode::Mul + | RiscvOpcode::Mulh + | RiscvOpcode::Mulhu + | RiscvOpcode::Mulhsu + | RiscvOpcode::Div + | RiscvOpcode::Divu + | RiscvOpcode::Rem + | RiscvOpcode::Remu + ); + if is_rv32m { + let rs1_val = reg_read_lane0.value; + let rs2_val = reg_read_lane1.value; + let shout_id = RiscvShoutTables::new(/*xlen=*/ 32).opcode_to_id(*op); + let key = interleave_bits(rs1_val, rs2_val) as u64; + let value = compute_op(*op, rs1_val, rs2_val, /*xlen=*/ 32); + shout_events.push(ShoutEvent { shout_id, key, value }); + } + } + } + + Ok(Self { + active: true, + cycle: step.cycle, + pc_before: step.pc_before, + pc_after: step.pc_after, + instr_word, + fields, + halted: step.halted, + decoded: Some(decoded), + prog_read: Some(prog_read), + reg_read_lane0: Some(reg_read_lane0), + reg_read_lane1: Some(reg_read_lane1), + reg_write_lane0, + ram_events, + shout_events, + }) + } + + pub fn inactive(cycle: u64, pc: u64, halted: bool) -> Self { + Self { + active: false, + cycle, + pc_before: pc, + pc_after: pc, + instr_word: 0, + fields: Rv32InstrFields::from_word(0), + halted, + decoded: None, + prog_read: None, + reg_read_lane0: None, + reg_read_lane1: None, + reg_write_lane0: None, + ram_events: Vec::new(), + shout_events: Vec::new(), + } + } +} + +#[derive(Clone, Debug)] +pub struct Rv32MEventRow { + pub cycle: u64, + pub pc: u64, + pub opcode: RiscvOpcode, + pub rs1: u8, + pub rs2: u8, + pub rd: u8, + pub rs1_val: u64, + pub rs2_val: u64, + pub rd_write_val: Option, + pub expected_rd_val: u64, +} + +#[derive(Clone, Debug)] +pub struct Rv32MEventTable { + pub rows: Vec, +} + +impl Rv32MEventTable { + pub fn from_exec_table(exec: &Rv32ExecTable) -> Result { + let mut rows = Vec::new(); + + for r in &exec.rows { + if !r.active { + continue; + } + let Some(decoded) = &r.decoded else { + continue; + }; + let (op, rd, rs1, rs2) = match decoded { + RiscvInstruction::RAlu { op, rd, rs1, rs2 } => (*op, *rd, *rs1, *rs2), + _ => continue, 
+ }; + + let is_rv32m = matches!( + op, + RiscvOpcode::Mul + | RiscvOpcode::Mulh + | RiscvOpcode::Mulhu + | RiscvOpcode::Mulhsu + | RiscvOpcode::Div + | RiscvOpcode::Divu + | RiscvOpcode::Rem + | RiscvOpcode::Remu + ); + if !is_rv32m { + continue; + } + + let rs1_val = r + .reg_read_lane0 + .as_ref() + .ok_or_else(|| format!("missing REG lane0 read on RV32M row at cycle {}", r.cycle))? + .value; + let rs2_val = r + .reg_read_lane1 + .as_ref() + .ok_or_else(|| format!("missing REG lane1 read on RV32M row at cycle {}", r.cycle))? + .value; + let expected = compute_op(op, rs1_val, rs2_val, /*xlen=*/ 32); + let rd_write_val = r.reg_write_lane0.as_ref().map(|w| w.value); + + // The trace should not write to x0; keep the event row but require no write event. + if rd == 0 && rd_write_val.is_some() { + return Err(format!( + "unexpected x0 write event on RV32M row at cycle {} pc={:#x}", + r.cycle, r.pc_before + )); + } + if rd != 0 && rd_write_val.is_none() { + return Err(format!( + "missing rd write event on RV32M row at cycle {} pc={:#x} (rd={rd})", + r.cycle, r.pc_before + )); + } + + rows.push(Rv32MEventRow { + cycle: r.cycle, + pc: r.pc_before, + opcode: op, + rs1, + rs2, + rd, + rs1_val, + rs2_val, + rd_write_val, + expected_rd_val: expected, + }); + } + + Ok(Self { rows }) + } +} + +#[derive(Clone, Debug)] +pub struct Rv32ShoutEventRow { + /// Row index within the padded exec table (0..t). + pub row_idx: usize, + pub cycle: u64, + pub pc: u64, + pub shout_id: u32, + pub opcode: Option, + /// Canonicalized key: for shift ops, `rhs` is masked to 5 bits. 
+ pub key: u64, + pub lhs: u64, + pub rhs: u64, + pub value: u64, +} + +#[derive(Clone, Debug)] +pub struct Rv32ShoutEventTable { + pub rows: Vec, +} + +impl Rv32ShoutEventTable { + pub fn from_exec_table(exec: &Rv32ExecTable) -> Result { + let shout_tables = RiscvShoutTables::new(/*xlen=*/ 32); + let mut rows = Vec::new(); + + for (row_idx, r) in exec.rows.iter().enumerate() { + if !r.active { + continue; + } + for ev in r.shout_events.iter() { + let opcode = shout_tables.id_to_opcode(ev.shout_id); + let (lhs, rhs_raw) = uninterleave_bits(ev.key as u128); + let rhs = if matches!(opcode, Some(RiscvOpcode::Sll | RiscvOpcode::Srl | RiscvOpcode::Sra)) { + rhs_raw & 0x1F + } else { + rhs_raw + }; + let key = if rhs != rhs_raw { + interleave_bits(lhs, rhs) as u64 + } else { + ev.key + }; + + rows.push(Rv32ShoutEventRow { + row_idx, + cycle: r.cycle, + pc: r.pc_before, + shout_id: ev.shout_id.0, + opcode, + key, + lhs, + rhs, + value: ev.value, + }); + } + } + + Ok(Self { rows }) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Rv32RegEventKind { + ReadLane0, + ReadLane1, + WriteLane0, +} + +#[derive(Clone, Debug)] +pub struct Rv32RegEventRow { + pub cycle: u64, + pub pc: u64, + pub kind: Rv32RegEventKind, + pub addr: u8, + pub prev_val: u64, + pub next_val: u64, +} + +#[derive(Clone, Debug)] +pub struct Rv32RegEventTable { + pub rows: Vec, +} + +impl Rv32RegEventTable { + pub fn from_exec_table(exec: &Rv32ExecTable, init_regs: &HashMap) -> Result { + let mut regs = [0u64; 32]; + for (&addr, &value) in init_regs { + if addr >= 32 { + return Err(format!("reg init addr out of range: addr={addr}")); + } + if addr == 0 && value != 0 { + return Err("reg init must keep x0 == 0".into()); + } + regs[addr as usize] = value; + } + + let mut rows: Vec = Vec::new(); + for r in &exec.rows { + if !r.active { + continue; + } + + let Some(rs1) = &r.reg_read_lane0 else { + return Err(format!("missing REG lane0 read at cycle {}", r.cycle)); + }; + let Some(rs2) = 
&r.reg_read_lane1 else { + return Err(format!("missing REG lane1 read at cycle {}", r.cycle)); + }; + if rs1.addr >= 32 || rs2.addr >= 32 { + return Err(format!( + "REG read addr out of range at cycle {}: lane0={} lane1={}", + r.cycle, rs1.addr, rs2.addr + )); + } + + // Reads happen before the optional write. + let rs1_prev = regs[rs1.addr as usize]; + let rs2_prev = regs[rs2.addr as usize]; + if rs1.value != rs1_prev { + return Err(format!( + "REG lane0 read value mismatch at cycle {} pc={:#x}: addr={} got={:#x} expected={:#x}", + r.cycle, r.pc_before, rs1.addr, rs1.value, rs1_prev + )); + } + if rs2.value != rs2_prev { + return Err(format!( + "REG lane1 read value mismatch at cycle {} pc={:#x}: addr={} got={:#x} expected={:#x}", + r.cycle, r.pc_before, rs2.addr, rs2.value, rs2_prev + )); + } + + rows.push(Rv32RegEventRow { + cycle: r.cycle, + pc: r.pc_before, + kind: Rv32RegEventKind::ReadLane0, + addr: rs1.addr as u8, + prev_val: rs1_prev, + next_val: rs1_prev, + }); + rows.push(Rv32RegEventRow { + cycle: r.cycle, + pc: r.pc_before, + kind: Rv32RegEventKind::ReadLane1, + addr: rs2.addr as u8, + prev_val: rs2_prev, + next_val: rs2_prev, + }); + + if let Some(w) = &r.reg_write_lane0 { + if w.addr >= 32 { + return Err(format!( + "REG write addr out of range at cycle {}: addr={}", + r.cycle, w.addr + )); + } + if w.addr == 0 { + return Err(format!( + "unexpected x0 write at cycle {} pc={:#x}", + r.cycle, r.pc_before + )); + } + + let prev = regs[w.addr as usize]; + let next = w.value; + regs[w.addr as usize] = next; + regs[0] = 0; + + rows.push(Rv32RegEventRow { + cycle: r.cycle, + pc: r.pc_before, + kind: Rv32RegEventKind::WriteLane0, + addr: w.addr as u8, + prev_val: prev, + next_val: next, + }); + } + } + + Ok(Self { rows }) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Rv32RamEventKind { + Read, + Write, +} + +#[derive(Clone, Debug)] +pub struct Rv32RamEventRow { + pub cycle: u64, + pub pc: u64, + pub kind: Rv32RamEventKind, + pub addr: u64, 
+ pub prev_val: u64, + pub next_val: u64, +} + +#[derive(Clone, Debug)] +pub struct Rv32RamEventTable { + pub rows: Vec, +} + +impl Rv32RamEventTable { + pub fn from_exec_table(exec: &Rv32ExecTable, init_ram: &HashMap) -> Result { + let mut mem: HashMap = HashMap::new(); + for (&addr, &value) in init_ram { + if value == 0 { + continue; + } + mem.insert(addr, value); + } + + let mut rows: Vec = Vec::new(); + for r in &exec.rows { + if !r.active { + continue; + } + + for e in &r.ram_events { + match e.kind { + TwistOpKind::Read => { + let prev = mem.get(&e.addr).copied().unwrap_or(0); + let next = prev; + if e.value != prev { + return Err(format!( + "RAM read value mismatch at cycle {} pc={:#x}: addr={:#x} got={:#x} expected={:#x}", + r.cycle, r.pc_before, e.addr, e.value, prev + )); + } + rows.push(Rv32RamEventRow { + cycle: r.cycle, + pc: r.pc_before, + kind: Rv32RamEventKind::Read, + addr: e.addr, + prev_val: prev, + next_val: next, + }); + } + TwistOpKind::Write => { + let prev = mem.get(&e.addr).copied().unwrap_or(0); + let next = e.value; + if next == 0 { + mem.remove(&e.addr); + } else { + mem.insert(e.addr, next); + } + rows.push(Rv32RamEventRow { + cycle: r.cycle, + pc: r.pc_before, + kind: Rv32RamEventKind::Write, + addr: e.addr, + prev_val: prev, + next_val: next, + }); + } + } + } + } + + Ok(Self { rows }) + } +} diff --git a/crates/neo-memory/src/riscv/lookups/cpu.rs b/crates/neo-memory/src/riscv/lookups/cpu.rs index 4c3094bb..6c315d7f 100644 --- a/crates/neo-memory/src/riscv/lookups/cpu.rs +++ b/crates/neo-memory/src/riscv/lookups/cpu.rs @@ -5,7 +5,6 @@ use super::decode::decode_instruction; use super::encode::encode_instruction; use super::isa::{BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; use super::tables::RiscvShoutTables; -use super::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM}; /// A RISC-V CPU that can be traced using Neo's VmCpu trait. 
/// @@ -85,10 +84,16 @@ impl RiscvCpu { } fn handle_ecall(&mut self) { - let call_id = self.get_reg(10) as u32; // a0 - if call_id != JOLT_CYCLE_TRACK_ECALL_NUM && call_id != JOLT_PRINT_ECALL_NUM { - self.halted = true; + self.halted = true; + } + + fn write_reg>(&mut self, twist: &mut T, reg: u8, value: u64) { + if reg == 0 { + return; } + let masked = self.mask_value(value); + twist.store_lane(super::REG_ID, reg as u64, masked, /*lane=*/ 0); + self.regs[reg as usize] = masked; } } @@ -152,17 +157,39 @@ impl neo_vm_trace::VmCpu for RiscvCpu { ) })?; + // -------------------------------------------------------------------- + // Regfile-as-Twist (REG_ID): always emit two register reads per step. + // + // Lane assignment (RV32 trace convention): + // - lane 0: read rs1_field + // - lane 1: read rs2_field + // -------------------------------------------------------------------- + let reg = super::REG_ID; + let rs1_field = ((instr_word_u32 >> 15) & 0x1f) as u64; + let rs2_field = ((instr_word_u32 >> 20) & 0x1f) as u64; + let rs2_addr = rs2_field; + + let rs1_val = self.mask_value(twist.load_lane(reg, rs1_field, /*lane=*/ 0)); + let rs2_val = self.mask_value(twist.load_lane(reg, rs2_addr, /*lane=*/ 1)); + + // Keep the CPU's register snapshot mirror consistent with Twist state. + self.regs[0] = 0; + if rs1_field != 0 { + self.regs[rs1_field as usize] = rs1_val; + } + if rs2_addr != 0 { + self.regs[rs2_addr as usize] = rs2_val; + } + // Default: advance PC by 4 let mut next_pc = self.pc.wrapping_add(4); let step_opcode: u32 = instr_word_u32; match instr { - RiscvInstruction::RAlu { op, rd, rs1, rs2 } => { - let rs1_val = self.get_reg(rs1); - let rs2_val = self.get_reg(rs2); - + RiscvInstruction::RAlu { op, rd, rs1: _, rs2: _ } => { match op { - // For RV32 B1, prove all M ops in-circuit (avoid implicit Shout tables). + // RV32 trace mode does not require Shout tables for RV32M semantics. 
+ // (They are checked by the RV32M sidecar CCS; Shout is only used for the remainder-bound SLTU check.) RiscvOpcode::Mul | RiscvOpcode::Mulh | RiscvOpcode::Mulhu @@ -181,22 +208,22 @@ impl neo_vm_trace::VmCpu for RiscvCpu { match op { RiscvOpcode::Mul => { let result = rs1_u32.wrapping_mul(rs2_u32) as u64; - self.set_reg(rd, result); + self.write_reg(twist, rd, result); } RiscvOpcode::Mulh => { let product = (rs1_i32 as i64) * (rs2_i32 as i64); let result = (product >> 32) as i32 as u32; - self.set_reg(rd, result as u64); + self.write_reg(twist, rd, result as u64); } RiscvOpcode::Mulhu => { let product = (rs1_u32 as u64) * (rs2_u32 as u64); let result = (product >> 32) as u32; - self.set_reg(rd, result as u64); + self.write_reg(twist, rd, result as u64); } RiscvOpcode::Mulhsu => { let product = (rs1_i32 as i64) * (rs2_u32 as i64); let result = (product >> 32) as i32 as u32; - self.set_reg(rd, result as u64); + self.write_reg(twist, rd, result as u64); } RiscvOpcode::Div | RiscvOpcode::Rem => { let divisor_is_zero = rs2_u32 == 0; @@ -212,7 +239,7 @@ impl neo_vm_trace::VmCpu for RiscvCpu { RiscvOpcode::Rem => rem_i32 as u32, _ => unreachable!(), }; - self.set_reg(rd, result as u64); + self.write_reg(twist, rd, result as u64); // Record a single Shout event for the remainder bound, only when divisor != 0. if !divisor_is_zero { @@ -236,7 +263,7 @@ impl neo_vm_trace::VmCpu for RiscvCpu { RiscvOpcode::Remu => rem, _ => unreachable!(), }; - self.set_reg(rd, result); + self.write_reg(twist, rd, result); // Record a single Shout event for the remainder bound, only when divisor != 0. 
if divisor != 0 { @@ -253,13 +280,12 @@ impl neo_vm_trace::VmCpu for RiscvCpu { let shout_id = shout_tables.opcode_to_id(op); let index = interleave_bits(rs1_val, rs2_val) as u64; let result = shout.lookup(shout_id, index); - self.set_reg(rd, result); + self.write_reg(twist, rd, result); } } } - RiscvInstruction::IAlu { op, rd, rs1, imm } => { - let rs1_val = self.get_reg(rs1); + RiscvInstruction::IAlu { op, rd, rs1: _, imm } => { let imm_val = self.sign_extend_imm(imm); // Use Shout for the ALU operation @@ -267,16 +293,16 @@ impl neo_vm_trace::VmCpu for RiscvCpu { let index = interleave_bits(rs1_val, imm_val) as u64; let result = shout.lookup(shout_id, index); - self.set_reg(rd, result); + self.write_reg(twist, rd, result); } - RiscvInstruction::Load { op, rd, rs1, imm } => { - let base = self.get_reg(rs1); + RiscvInstruction::Load { op, rd, rs1: _, imm } => { + let base = rs1_val; let imm_val = self.sign_extend_imm(imm); let index = interleave_bits(base, imm_val) as u64; let addr = shout.lookup(add_shout_id, index); - // Twist RAM semantics (RV32 B1 / MVP): + // Twist RAM semantics (RV32 trace mode): // - Memory is byte-addressed, and `addr` is a byte address. // - Twist accesses are word-valued (XLEN bits), i.e. a `load/store` at `addr` // reads/writes the little-endian word window starting at `addr`. 
@@ -308,15 +334,20 @@ impl neo_vm_trace::VmCpu for RiscvCpu { value }; - self.set_reg(rd, self.mask_value(result)); + self.write_reg(twist, rd, result); } - RiscvInstruction::Store { op, rs1, rs2, imm } => { - let base = self.get_reg(rs1); + RiscvInstruction::Store { + op, + rs1: _, + rs2: _, + imm, + } => { + let base = rs1_val; let imm_val = self.sign_extend_imm(imm); let index = interleave_bits(base, imm_val) as u64; let addr = shout.lookup(add_shout_id, index); - let value = self.get_reg(rs2); + let value = rs2_val; // Mask value to store width let width = op.width_bytes(); @@ -340,10 +371,12 @@ impl neo_vm_trace::VmCpu for RiscvCpu { } } - RiscvInstruction::Branch { cond, rs1, rs2, imm } => { - let rs1_val = self.get_reg(rs1); - let rs2_val = self.get_reg(rs2); - + RiscvInstruction::Branch { + cond, + rs1: _, + rs2: _, + imm, + } => { // Use Shout for the comparison let shout_id = shout_tables.opcode_to_id(cond.to_shout_opcode()); let index = interleave_bits(rs1_val, rs2_val) as u64; @@ -355,8 +388,12 @@ impl neo_vm_trace::VmCpu for RiscvCpu { } let taken = match cond { - BranchCondition::Eq | BranchCondition::Ne | BranchCondition::Lt | BranchCondition::Ltu => cmp == 1, - BranchCondition::Ge | BranchCondition::Geu => cmp == 0, + // EQ/SLT/SLTU return 1 when the predicate holds. 
+ BranchCondition::Eq | BranchCondition::Lt | BranchCondition::Ltu => cmp == 1, + // Inverted predicates: + // - Ne uses Eq lookup: taken = !(rs1 == rs2) + // - Ge/Geu use Slt/Sltu lookup: taken = !(rs1 < rs2) + BranchCondition::Ne | BranchCondition::Ge | BranchCondition::Geu => cmp == 0, }; if taken { let imm_u = self.sign_extend_imm(imm); @@ -366,14 +403,13 @@ impl neo_vm_trace::VmCpu for RiscvCpu { RiscvInstruction::Jal { rd, imm } => { // rd = pc + 4 (return address) - self.set_reg(rd, self.pc.wrapping_add(4)); + self.write_reg(twist, rd, self.pc.wrapping_add(4)); // pc = pc + imm let imm_u = self.sign_extend_imm(imm); next_pc = self.pc.wrapping_add(imm_u); } - RiscvInstruction::Jalr { rd, rs1, imm } => { - let rs1_val = self.get_reg(rs1); + RiscvInstruction::Jalr { rd, rs1: _, imm } => { let return_addr = self.pc.wrapping_add(4); // pc = (rs1 + imm) & !3 (MVP: no compressed instructions) @@ -384,13 +420,13 @@ impl neo_vm_trace::VmCpu for RiscvCpu { next_pc = sum & !3u64; // rd = return address - self.set_reg(rd, return_addr); + self.write_reg(twist, rd, return_addr); } RiscvInstruction::Lui { rd, imm } => { // rd = imm << 12 (upper 20 bits) let value = (imm as i64 as u64) << 12; - self.set_reg(rd, self.mask_value(value)); + self.write_reg(twist, rd, value); } RiscvInstruction::Auipc { rd, imm } => { @@ -398,42 +434,38 @@ impl neo_vm_trace::VmCpu for RiscvCpu { let imm_u = self.mask_value((imm as i64 as u64) << 12); let index = interleave_bits(self.pc, imm_u) as u64; let value = shout.lookup(add_shout_id, index); - self.set_reg(rd, value); + self.write_reg(twist, rd, value); } RiscvInstruction::Halt => { - // ECALL trap semantics (Jolt-style): no architectural effects, halt unless it's a known marker/print call. + // ECALL trap semantics: halt. 
self.handle_ecall(); } RiscvInstruction::Nop => {} // === RV64 W-suffix Operations === - RiscvInstruction::RAluw { op, rd, rs1, rs2 } => { - let rs1_val = self.get_reg(rs1); - let rs2_val = self.get_reg(rs2); - + RiscvInstruction::RAluw { op, rd, rs1: _, rs2: _ } => { let shout_id = shout_tables.opcode_to_id(op); let index = interleave_bits(rs1_val, rs2_val) as u64; let result = shout.lookup(shout_id, index); - self.set_reg(rd, result); + self.write_reg(twist, rd, result); } - RiscvInstruction::IAluw { op, rd, rs1, imm } => { - let rs1_val = self.get_reg(rs1); + RiscvInstruction::IAluw { op, rd, rs1: _, imm } => { let imm_val = self.sign_extend_imm(imm); let shout_id = shout_tables.opcode_to_id(op); let index = interleave_bits(rs1_val, imm_val) as u64; let result = shout.lookup(shout_id, index); - self.set_reg(rd, result); + self.write_reg(twist, rd, result); } // === A Extension: Atomics === - RiscvInstruction::LoadReserved { op, rd, rs1 } => { - let addr = self.get_reg(rs1); + RiscvInstruction::LoadReserved { op, rd, rs1: _ } => { + let addr = rs1_val; let value = twist.load(ram, addr); // Apply width and sign extension @@ -452,13 +484,13 @@ impl neo_vm_trace::VmCpu for RiscvCpu { value & mask }; - self.set_reg(rd, self.mask_value(result)); + self.write_reg(twist, rd, result); // Note: In a real implementation, we'd reserve the address here } - RiscvInstruction::StoreConditional { op, rd, rs1, rs2 } => { - let addr = self.get_reg(rs1); - let value = self.get_reg(rs2); + RiscvInstruction::StoreConditional { op, rd, rs1: _, rs2: _ } => { + let addr = rs1_val; + let value = rs2_val; // Mask value to store width let width = op.width_bytes(); @@ -473,16 +505,16 @@ impl neo_vm_trace::VmCpu for RiscvCpu { twist.store(ram, addr, store_value); // SC returns 0 on success (assuming reservation is valid in single-threaded mode) - self.set_reg(rd, 0); + self.write_reg(twist, rd, 0); } - RiscvInstruction::Amo { op, rd, rs1, rs2 } => { - let addr = self.get_reg(rs1); - let src 
= self.mask_value(self.get_reg(rs2)); + RiscvInstruction::Amo { op, rd, rs1: _, rs2: _ } => { + let addr = rs1_val; + let src = rs2_val; // Load original value let original = self.mask_value(twist.load(ram, addr)); - self.set_reg(rd, original); + self.write_reg(twist, rd, original); // Compute new value based on AMO operation let new_val = match op { diff --git a/crates/neo-memory/src/riscv/lookups/decode.rs b/crates/neo-memory/src/riscv/lookups/decode.rs index a14ed789..303a4feb 100644 --- a/crates/neo-memory/src/riscv/lookups/decode.rs +++ b/crates/neo-memory/src/riscv/lookups/decode.rs @@ -181,7 +181,7 @@ pub fn decode_instruction(instr: u32) -> Result { Ok(RiscvInstruction::Auipc { rd, imm }) } - // SYSTEM (1110011) - ECALL (Jolt-style trap; no EBREAK support) + // SYSTEM (1110011) - ECALL (trap/terminate in this VM) 0b1110011 => { if instr == 0x0000_0073 { Ok(RiscvInstruction::Halt) // ECALL (trap/terminate in this VM) @@ -190,7 +190,7 @@ pub fn decode_instruction(instr: u32) -> Result { } } - // MISC-MEM (0001111) - FENCE (FENCE.I unsupported to match Jolt tracer) + // MISC-MEM (0001111) - FENCE (FENCE.I unsupported) 0b0001111 => { if funct3 != 0b000 { return Err(format!("Unsupported MISC-MEM instruction: funct3={:#x}", funct3)); diff --git a/crates/neo-memory/src/riscv/lookups/isa.rs b/crates/neo-memory/src/riscv/lookups/isa.rs index 02fd460b..792c4a12 100644 --- a/crates/neo-memory/src/riscv/lookups/isa.rs +++ b/crates/neo-memory/src/riscv/lookups/isa.rs @@ -394,7 +394,8 @@ impl BranchCondition { pub fn to_shout_opcode(&self) -> RiscvOpcode { match self { BranchCondition::Eq => RiscvOpcode::Eq, - BranchCondition::Ne => RiscvOpcode::Neq, + // Represent BNE as EQ + invert (avoids a dedicated NEQ table/lane). 
+ BranchCondition::Ne => RiscvOpcode::Eq, BranchCondition::Lt => RiscvOpcode::Slt, BranchCondition::Ge => RiscvOpcode::Slt, // BGE = !(rs1 < rs2) BranchCondition::Ltu => RiscvOpcode::Sltu, diff --git a/crates/neo-memory/src/riscv/lookups/memory.rs b/crates/neo-memory/src/riscv/lookups/memory.rs index a3806a30..32d2e41c 100644 --- a/crates/neo-memory/src/riscv/lookups/memory.rs +++ b/crates/neo-memory/src/riscv/lookups/memory.rs @@ -30,6 +30,10 @@ impl RiscvMemoryEvent { pub struct RiscvMemory { /// Memory contents (sparse representation). data: HashMap<(TwistId, u64), u8>, + /// Architectural register file contents (x0..x31), word-addressed. + /// + /// This is stored separately from `data` because registers are not byte-addressed. + regs: [u64; 32], /// Word size in bits (32 or 64). pub xlen: usize, } @@ -39,6 +43,7 @@ impl RiscvMemory { pub fn new(xlen: usize) -> Self { Self { data: HashMap::new(), + regs: [0u64; 32], xlen, } } @@ -123,6 +128,15 @@ impl RiscvMemory { impl Twist for RiscvMemory { fn load(&mut self, twist_id: TwistId, addr: u64) -> u64 { + if twist_id == super::REG_ID { + let idx = addr as usize; + debug_assert!(idx < 32, "REG_ID addr out of range: {}", idx); + if idx == 0 { + return 0; + } + return self.regs.get(idx).copied().unwrap_or(0); + } + let width = if twist_id == super::PROG_ID { // Program ROM fetch: always 32-bit instruction word (MVP: no compressed). 
4 @@ -134,6 +148,19 @@ impl Twist for RiscvMemory { } fn store(&mut self, twist_id: TwistId, addr: u64, value: u64) { + if twist_id == super::REG_ID { + let idx = addr as usize; + debug_assert!(idx < 32, "REG_ID addr out of range: {}", idx); + if idx == 0 { + return; + } + let masked = if self.xlen == 32 { value as u32 as u64 } else { value }; + if let Some(dst) = self.regs.get_mut(idx) { + *dst = masked; + } + return; + } + let width = if twist_id == super::PROG_ID { 4 } else { self.xlen / 8 }; self.write(twist_id, addr, width, value); } diff --git a/crates/neo-memory/src/riscv/lookups/mod.rs b/crates/neo-memory/src/riscv/lookups/mod.rs index 4753eaec..65866a8b 100644 --- a/crates/neo-memory/src/riscv/lookups/mod.rs +++ b/crates/neo-memory/src/riscv/lookups/mod.rs @@ -7,7 +7,7 @@ //! //! # Proving integration scope (today) //! -//! The shared-bus RV32 B1 proving path assumes: +//! The shared-bus RV32 trace-wiring proving path assumes: //! - `xlen == 32` (RV32) //! - no compressed (RVC) instructions //! - 4-byte aligned PC and control-flow targets @@ -60,7 +60,7 @@ //! - Automatic detection of 16-bit vs 32-bit instructions //! //! ## System -//! - ECALL (Jolt-style markers), FENCE +//! - ECALL, FENCE //! - EBREAK and FENCE.I are not supported //! //! # Example @@ -98,12 +98,13 @@ use neo_vm_trace::TwistId; /// Canonical Twist instance id for RISC-V data RAM. pub const RAM_ID: TwistId = TwistId(0); -/// Canonical Twist instance id for the program ROM (B1 instruction fetch). +/// Canonical Twist instance id for the program ROM instruction fetch. pub const PROG_ID: TwistId = TwistId(1); -/// Jolt ECALL identifiers for marker/print syscalls. -pub const JOLT_CYCLE_TRACK_ECALL_NUM: u32 = 0xC7C1E; -pub const JOLT_PRINT_ECALL_NUM: u32 = 0x505249; +/// Canonical Twist instance id for the architectural register file (x0..x31). +/// +/// This is used by the RV32 trace-wiring circuit in "regfile-as-Twist" mode. 
+pub const REG_ID: TwistId = TwistId(2); pub use alu::{compute_op, lookup_entry}; pub use bits::{interleave_bits, uninterleave_bits}; diff --git a/crates/neo-memory/src/riscv/mod.rs b/crates/neo-memory/src/riscv/mod.rs index 32fa18ff..a3307a15 100644 --- a/crates/neo-memory/src/riscv/mod.rs +++ b/crates/neo-memory/src/riscv/mod.rs @@ -4,7 +4,9 @@ pub mod ccs; pub mod elf_loader; +pub mod exec_table; pub mod lookups; pub mod rom_init; -pub mod shard; pub mod shout_oracle; +pub mod sparse_access; +pub mod trace; diff --git a/crates/neo-memory/src/riscv/rom_init.rs b/crates/neo-memory/src/riscv/rom_init.rs index f284942c..01e7b6b0 100644 --- a/crates/neo-memory/src/riscv/rom_init.rs +++ b/crates/neo-memory/src/riscv/rom_init.rs @@ -74,7 +74,7 @@ pub fn prog_rom_layout_and_init_words( /// The ROM value at address `base + 4*i` is the little-endian `u32` formed from /// `program_bytes[4*i..4*i+4]`. /// -/// This matches the RV32 B1 step circuit convention: +/// This matches the RV32 trace-wiring convention: /// - instruction fetch address is the architectural PC (byte address), /// - instruction fetch value is a 32-bit word. 
pub fn prog_init_words( diff --git a/crates/neo-memory/src/riscv/shard.rs b/crates/neo-memory/src/riscv/shard.rs deleted file mode 100644 index 07f15364..00000000 --- a/crates/neo-memory/src/riscv/shard.rs +++ /dev/null @@ -1,84 +0,0 @@ -use p3_field::PrimeCharacteristicRing; -use p3_goldilocks::Goldilocks as F; - -use crate::riscv::ccs::{rv32_b1_step_linking_pairs, Rv32B1Layout}; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Rv32BoundaryState { - pub pc0: F, - pub regs0: [F; 32], - pub pc_final: F, - pub regs_final: [F; 32], - pub halted_in: F, - pub halted_out: F, -} - -pub fn extract_boundary_state(layout: &Rv32B1Layout, x: &[F]) -> Result { - let required = [ - layout.pc0, - layout.regs0_start + 31, - layout.pc_final, - layout.regs_final_start + 31, - layout.halted_in, - layout.halted_out, - ]; - let max = required.into_iter().max().unwrap_or(0); - if max >= x.len() { - return Err(format!( - "public x too short for RV32 boundary extraction: need idx {max} but x.len()={}", - x.len() - )); - } - - let mut regs0 = [F::ZERO; 32]; - let mut regs_final = [F::ZERO; 32]; - for r in 0..32 { - regs0[r] = x[layout.regs0_start + r]; - regs_final[r] = x[layout.regs_final_start + r]; - } - - Ok(Rv32BoundaryState { - pc0: x[layout.pc0], - regs0, - pc_final: x[layout.pc_final], - regs_final, - halted_in: x[layout.halted_in], - halted_out: x[layout.halted_out], - }) -} - -pub fn check_rv32_b1_chunk_chaining(layout: &Rv32B1Layout, chunk_publics: &[&[F]]) -> Result<(), String> { - if chunk_publics.len() <= 1 { - return Ok(()); - } - - let pairs = rv32_b1_step_linking_pairs(layout); - for (i, (a, b)) in chunk_publics - .iter() - .zip(chunk_publics.iter().skip(1)) - .enumerate() - { - for &(out_idx, in_idx) in &pairs { - let out = a.get(out_idx).copied().ok_or_else(|| { - format!( - "chunk {i} public x too short for linking: need idx {out_idx} but x.len()={}", - a.len() - ) - })?; - let inn = b.get(in_idx).copied().ok_or_else(|| { - format!( - "chunk {} public x too short 
for linking: need idx {in_idx} but x.len()={}", - i + 1, - b.len() - ) - })?; - if out != inn { - return Err(format!( - "RV32 chunk linking mismatch at boundary {i}: x_i[{out_idx}] != x_{}[{in_idx}]", - i + 1 - )); - } - } - } - Ok(()) -} diff --git a/crates/neo-memory/src/riscv/sparse_access.rs b/crates/neo-memory/src/riscv/sparse_access.rs new file mode 100644 index 00000000..26cf50ae --- /dev/null +++ b/crates/neo-memory/src/riscv/sparse_access.rs @@ -0,0 +1,149 @@ +//! Sparse access representations for RISC-V sidecars. +//! +//! This module does **not** implement Jolt's full instruction-lookup protocol. +//! It only provides small, reusable building blocks inspired by Jolt's approach: +//! represent read-access patterns as sparse matrices over (address, cycle). +//! +//! These helpers are intended for future Tier-2.1+ work where Shout/ALU sidecars +//! move from packed "bus slices" toward true sparse read matrices (InstructionRa-like). + +use neo_math::K; +use p3_field::PrimeCharacteristicRing; + +use crate::mle::build_chi_table; +use crate::sparse_matrix::{SparseMat, SparseMatEntry}; + +use super::exec_table::Rv32ShoutEventTable; + +/// Build sparse `(addr, cycle)` matrices for RV32 Shout events. +/// +/// - `ra(addr, cycle)` is 1 for each executed lookup event. +/// - `val(addr, cycle)` is the lookup value for each executed event. 
+/// +/// Dimensions: +/// - address domain: 64 bits (`ell_addr = 64`), with `addr = event.key` (interleaved RV32 operands) +/// - cycle domain: `ell_cycle` bits, with `cycle = event.row_idx` (exec-table row index) +pub fn rv32_shout_event_table_to_sparse_ra_and_val( + events: &Rv32ShoutEventTable, + ell_cycle: usize, +) -> Result<(SparseMat, SparseMat), String> { + if ell_cycle >= 64 { + return Err(format!( + "rv32_shout_event_table_to_sparse_ra_and_val: ell_cycle={ell_cycle} too large for u64 cycle indices" + )); + } + let max_cycle = 1u64 + .checked_shl(ell_cycle as u32) + .ok_or_else(|| "rv32_shout_event_table_to_sparse_ra_and_val: 2^ell_cycle overflow".to_string())?; + + let mut ra_entries: Vec> = Vec::with_capacity(events.rows.len()); + let mut val_entries: Vec> = Vec::with_capacity(events.rows.len()); + + for row in events.rows.iter() { + let cycle = u64::try_from(row.row_idx) + .map_err(|_| "rv32_shout_event_table_to_sparse_ra_and_val: row_idx does not fit u64".to_string())?; + if cycle >= max_cycle { + return Err(format!( + "rv32_shout_event_table_to_sparse_ra_and_val: event row_idx {} out of range for ell_cycle={ell_cycle}", + row.row_idx + )); + } + + let addr = row.key; + ra_entries.push(SparseMatEntry { + row: addr, + col: cycle, + value: K::ONE, + }); + val_entries.push(SparseMatEntry { + row: addr, + col: cycle, + value: K::from_u64(row.value), + }); + } + + let ra = SparseMat::from_entries(/*ell_row=*/ 64, /*ell_col=*/ ell_cycle, ra_entries); + let val = SparseMat::from_entries(/*ell_row=*/ 64, /*ell_col=*/ ell_cycle, val_entries); + Ok((ra, val)) +} + +fn chi_at_u64_index(r: &[K], idx: u64) -> K { + let mut acc = K::ONE; + for (bit, &ri) in r.iter().enumerate() { + let is_one = ((idx >> bit) & 1) == 1; + acc *= if is_one { ri } else { K::ONE - ri }; + } + acc +} + +/// Evaluate the RV32 Shout event-table sparse matrices (RA and VAL) at `(r_addr, r_cycle)` +/// using a Jolt-style chunked address equality table. 
+/// +/// This is purely a helper for future "InstructionRa-like" protocols: it shows how to compute +/// `χ_{r_addr}(key)` without committing to `ell_addr=64` addr-bit columns. +pub fn rv32_shout_event_table_ra_val_mle_eval_chunked( + events: &Rv32ShoutEventTable, + r_addr: &[K], + r_cycle: &[K], + log_k_chunk: usize, +) -> Result<(K, K), String> { + if r_addr.len() != 64 { + return Err(format!( + "rv32_shout_event_table_ra_val_mle_eval_chunked: expected r_addr.len()=64, got {}", + r_addr.len() + )); + } + // This helper builds a dense χ table of length 2^log_k_chunk, so keep chunk sizes small. + // (Jolt commonly uses 8 or 16.) + if log_k_chunk == 0 || log_k_chunk > 16 { + return Err(format!( + "rv32_shout_event_table_ra_val_mle_eval_chunked: log_k_chunk must be in [1,16], got {log_k_chunk}" + )); + } + if 64 % log_k_chunk != 0 { + return Err(format!( + "rv32_shout_event_table_ra_val_mle_eval_chunked: log_k_chunk={log_k_chunk} must divide 64" + )); + } + if r_cycle.len() >= 64 { + return Err(format!( + "rv32_shout_event_table_ra_val_mle_eval_chunked: r_cycle.len()={} too large for u64 cycle indices", + r_cycle.len() + )); + } + + let mask: u64 = (1u64 << log_k_chunk) - 1; + let n_chunks = 64 / log_k_chunk; + + let mut eq_evals_by_chunk: Vec> = Vec::with_capacity(n_chunks); + for chunk in 0..n_chunks { + let start = chunk * log_k_chunk; + let end = start + log_k_chunk; + eq_evals_by_chunk.push(build_chi_table(&r_addr[start..end])); + } + + let mut ra_acc = K::ZERO; + let mut val_acc = K::ZERO; + for row in events.rows.iter() { + let cycle = u64::try_from(row.row_idx) + .map_err(|_| "rv32_shout_event_table_ra_val_mle_eval_chunked: row_idx does not fit u64".to_string())?; + let w_cycle = chi_at_u64_index(r_cycle, cycle); + + let mut w_addr = K::ONE; + for (chunk, eq_evals) in eq_evals_by_chunk.iter().enumerate() { + let shift = chunk * log_k_chunk; + let idx = ((row.key >> shift) & mask) as usize; + let w = eq_evals + .get(idx) + .copied() + .ok_or_else(|| 
"rv32_shout_event_table_ra_val_mle_eval_chunked: chunk idx out of range".to_string())?; + w_addr *= w; + } + + let w = w_cycle * w_addr; + ra_acc += w; + val_acc += K::from_u64(row.value) * w; + } + + Ok((ra_acc, val_acc)) +} diff --git a/crates/neo-memory/src/riscv/trace/air.rs b/crates/neo-memory/src/riscv/trace/air.rs new file mode 100644 index 00000000..373954fe --- /dev/null +++ b/crates/neo-memory/src/riscv/trace/air.rs @@ -0,0 +1,149 @@ +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +use super::{layout::Rv32TraceLayout, witness::Rv32TraceWitness}; + +#[derive(Clone, Debug)] +pub struct Rv32TraceAir { + pub layout: Rv32TraceLayout, +} + +impl Rv32TraceAir { + pub fn new() -> Self { + Self { + layout: Rv32TraceLayout::new(), + } + } + + #[inline] + fn is_zero(x: F) -> bool { + x == F::ZERO + } + + #[inline] + fn bool_check(x: F) -> F { + x * (x - F::ONE) + } + + #[inline] + fn gated_zero(gate: F, x: F) -> F { + gate * x + } + + pub fn assert_satisfied(&self, wit: &Rv32TraceWitness) -> Result<(), String> { + let l = &self.layout; + if wit.cols.len() != l.cols { + return Err(format!( + "trace witness width mismatch: got {} cols, expected {}", + wit.cols.len(), + l.cols + )); + } + for (c, col) in wit.cols.iter().enumerate() { + if col.len() != wit.t { + return Err(format!( + "trace witness column length mismatch at col {c}: got {}, expected {}", + col.len(), + wit.t + )); + } + } + + let col = |c: usize, i: usize| -> F { wit.cols[c][i] }; + + // Row-wise constraints. + for i in 0..wit.t { + let one = col(l.one, i); + if one != F::ONE { + return Err(format!("row {i}: one != 1")); + } + + let active = col(l.active, i); + let halted = col(l.halted, i); + let shout_has_lookup = col(l.shout_has_lookup, i); + + // Booleans. 
+ for (name, v) in [ + ("active", active), + ("halted", halted), + ("shout_has_lookup", shout_has_lookup), + ] { + let e = Self::bool_check(v); + if !Self::is_zero(e) { + return Err(format!("row {i}: {name} not boolean")); + } + } + // Padding invariants: inactive rows must not carry "hidden" values. + let inv_active = F::ONE - active; + for (name, c) in [ + ("instr_word", l.instr_word), + ("rs1_addr", l.rs1_addr), + ("rs1_val", l.rs1_val), + ("rs2_addr", l.rs2_addr), + ("rs2_val", l.rs2_val), + ("rd_addr", l.rd_addr), + ("rd_val", l.rd_val), + ("ram_addr", l.ram_addr), + ("ram_rv", l.ram_rv), + ("ram_wv", l.ram_wv), + ("shout_has_lookup", l.shout_has_lookup), + ("shout_val", l.shout_val), + ("shout_lhs", l.shout_lhs), + ("shout_rhs", l.shout_rhs), + ("jalr_drop_bit", l.jalr_drop_bit), + ] { + let e = Self::gated_zero(inv_active, col(c, i)); + if !Self::is_zero(e) { + return Err(format!("row {i}: inactive padding violated ({name} != 0)")); + } + } + + // Shout padding: if no lookup, the lookup output must be 0. + { + if !Self::is_zero(Self::gated_zero(F::ONE - shout_has_lookup, col(l.shout_val, i))) { + return Err(format!("row {i}: shout_val must be 0 when shout_has_lookup=0")); + } + if !Self::is_zero(Self::gated_zero(F::ONE - shout_has_lookup, col(l.shout_lhs, i))) { + return Err(format!("row {i}: shout_lhs must be 0 when shout_has_lookup=0")); + } + if !Self::is_zero(Self::gated_zero(F::ONE - shout_has_lookup, col(l.shout_rhs, i))) { + return Err(format!("row {i}: shout_rhs must be 0 when shout_has_lookup=0")); + } + } + } + + // Transition constraints. + for i in 0..wit.t.saturating_sub(1) { + let e = col(l.pc_after, i) - col(l.pc_before, i + 1); + if !Self::is_zero(e) { + return Err(format!("pc chain mismatch at row {i}")); + } + + let e = col(l.cycle, i + 1) - (col(l.cycle, i) + F::ONE); + if !Self::is_zero(e) { + return Err(format!("cycle chain mismatch at row {i}")); + } + + // Once inactive, remain inactive. 
+ let a0 = col(l.active, i); + let a1 = col(l.active, i + 1); + if !Self::is_zero(a1 * (F::ONE - a0)) { + return Err(format!("active monotonicity violated at row {i}")); + } + + // Once halted, remain halted. + let h0 = col(l.halted, i); + let h1 = col(l.halted, i + 1); + if !Self::is_zero(h0 * (F::ONE - h1)) { + return Err(format!("halted monotonicity violated at row {i}")); + } + + // HALT terminates execution: halted[i] => active[i+1] == 0. + if !Self::is_zero(h0 * a1) { + return Err(format!("halted tail quiescence violated at row {i}")); + } + } + + Ok(()) + } +} diff --git a/crates/neo-memory/src/riscv/trace/decode_lookup.rs b/crates/neo-memory/src/riscv/trace/decode_lookup.rs new file mode 100644 index 00000000..e584a8bd --- /dev/null +++ b/crates/neo-memory/src/riscv/trace/decode_lookup.rs @@ -0,0 +1,453 @@ +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +/// Base lookup table id for decode-column lookup families in shared-bus mode. +/// +/// Table id for decode column `c` is `RV32_TRACE_DECODE_LOOKUP_TABLE_BASE + c`. +pub const RV32_TRACE_DECODE_LOOKUP_TABLE_BASE: u32 = 0x5256_4400; +/// Base address-group id for decode lookup lanes. 
+pub const RV32_TRACE_DECODE_ADDR_GROUP_BASE: u32 = 0x5256_4A00; + +#[derive(Clone, Debug)] +pub struct Rv32DecodeSidecarLayout { + pub cols: usize, + pub opcode: usize, + pub funct3: usize, + pub funct7: usize, + pub rd: usize, + pub rs1: usize, + pub rs2: usize, + pub rd_has_write: usize, + pub ram_has_read: usize, + pub ram_has_write: usize, + pub shout_table_id: usize, + pub op_lui: usize, + pub op_auipc: usize, + pub op_jal: usize, + pub op_jalr: usize, + pub op_branch: usize, + pub op_load: usize, + pub op_store: usize, + pub op_alu_imm: usize, + pub op_alu_reg: usize, + pub op_misc_mem: usize, + pub op_system: usize, + pub op_amo: usize, + pub op_lui_write: usize, + pub op_auipc_write: usize, + pub op_jal_write: usize, + pub op_jalr_write: usize, + pub op_alu_imm_write: usize, + pub op_alu_reg_write: usize, + pub is_lb_write: usize, + pub is_lbu_write: usize, + pub is_lh_write: usize, + pub is_lhu_write: usize, + pub is_lw_write: usize, + pub funct3_is: [usize; 8], + pub alu_reg_table_delta: usize, + pub alu_imm_table_delta: usize, + pub alu_imm_shift_rhs_delta: usize, + pub imm_i: usize, + pub imm_s: usize, + pub imm_b: usize, + pub imm_j: usize, + pub rd_bit: [usize; 5], + pub funct3_bit: [usize; 3], + pub rs1_bit: [usize; 5], + pub rs2_bit: [usize; 5], + pub funct7_bit: [usize; 7], + pub rd_is_zero_01: usize, + pub rd_is_zero_012: usize, + pub rd_is_zero_0123: usize, + pub rd_is_zero: usize, +} + +impl Rv32DecodeSidecarLayout { + pub fn new() -> Self { + let mut next = 0usize; + let mut take = || { + let out = next; + next += 1; + out + }; + let opcode = take(); + let funct3 = take(); + let funct7 = take(); + let rd = take(); + let rs1 = take(); + let rs2 = take(); + let rd_has_write = take(); + let ram_has_read = take(); + let ram_has_write = take(); + let shout_table_id = take(); + let op_lui = take(); + let op_auipc = take(); + let op_jal = take(); + let op_jalr = take(); + let op_branch = take(); + let op_load = take(); + let op_store = take(); + let 
op_alu_imm = take(); + let op_alu_reg = take(); + let op_misc_mem = take(); + let op_system = take(); + let op_amo = take(); + let op_lui_write = take(); + let op_auipc_write = take(); + let op_jal_write = take(); + let op_jalr_write = take(); + let op_alu_imm_write = take(); + let op_alu_reg_write = take(); + let is_lb_write = take(); + let is_lbu_write = take(); + let is_lh_write = take(); + let is_lhu_write = take(); + let is_lw_write = take(); + let funct3_is_0 = take(); + let funct3_is_1 = take(); + let funct3_is_2 = take(); + let funct3_is_3 = take(); + let funct3_is_4 = take(); + let funct3_is_5 = take(); + let funct3_is_6 = take(); + let funct3_is_7 = take(); + let alu_reg_table_delta = take(); + let alu_imm_table_delta = take(); + let alu_imm_shift_rhs_delta = take(); + let imm_i = take(); + let imm_s = take(); + let imm_b = take(); + let imm_j = take(); + let rd_b0 = take(); + let rd_b1 = take(); + let rd_b2 = take(); + let rd_b3 = take(); + let rd_b4 = take(); + let funct3_b0 = take(); + let funct3_b1 = take(); + let funct3_b2 = take(); + let rs1_b0 = take(); + let rs1_b1 = take(); + let rs1_b2 = take(); + let rs1_b3 = take(); + let rs1_b4 = take(); + let rs2_b0 = take(); + let rs2_b1 = take(); + let rs2_b2 = take(); + let rs2_b3 = take(); + let rs2_b4 = take(); + let funct7_b0 = take(); + let funct7_b1 = take(); + let funct7_b2 = take(); + let funct7_b3 = take(); + let funct7_b4 = take(); + let funct7_b5 = take(); + let funct7_b6 = take(); + let rd_is_zero_01 = take(); + let rd_is_zero_012 = take(); + let rd_is_zero_0123 = take(); + let rd_is_zero = take(); + debug_assert_eq!(next, 77); + Self { + cols: next, + opcode, + funct3, + funct7, + rd, + rs1, + rs2, + rd_has_write, + ram_has_read, + ram_has_write, + shout_table_id, + op_lui, + op_auipc, + op_jal, + op_jalr, + op_branch, + op_load, + op_store, + op_alu_imm, + op_alu_reg, + op_misc_mem, + op_system, + op_amo, + op_lui_write, + op_auipc_write, + op_jal_write, + op_jalr_write, + op_alu_imm_write, + 
op_alu_reg_write, + is_lb_write, + is_lbu_write, + is_lh_write, + is_lhu_write, + is_lw_write, + funct3_is: [ + funct3_is_0, + funct3_is_1, + funct3_is_2, + funct3_is_3, + funct3_is_4, + funct3_is_5, + funct3_is_6, + funct3_is_7, + ], + alu_reg_table_delta, + alu_imm_table_delta, + alu_imm_shift_rhs_delta, + imm_i, + imm_s, + imm_b, + imm_j, + rd_bit: [rd_b0, rd_b1, rd_b2, rd_b3, rd_b4], + funct3_bit: [funct3_b0, funct3_b1, funct3_b2], + rs1_bit: [rs1_b0, rs1_b1, rs1_b2, rs1_b3, rs1_b4], + rs2_bit: [rs2_b0, rs2_b1, rs2_b2, rs2_b3, rs2_b4], + funct7_bit: [ + funct7_b0, funct7_b1, funct7_b2, funct7_b3, funct7_b4, funct7_b5, funct7_b6, + ], + rd_is_zero_01, + rd_is_zero_012, + rd_is_zero_0123, + rd_is_zero, + } + } +} + +#[inline] +pub fn rv32_decode_lookup_backed_cols(layout: &Rv32DecodeSidecarLayout) -> Vec { + let mut out = Vec::with_capacity(56); + out.push(layout.opcode); + out.push(layout.rs2); + out.push(layout.rd_has_write); + out.push(layout.ram_has_read); + out.push(layout.ram_has_write); + out.push(layout.shout_table_id); + out.extend_from_slice(&[ + layout.op_lui, + layout.op_auipc, + layout.op_jal, + layout.op_jalr, + layout.op_branch, + layout.op_load, + layout.op_store, + layout.op_alu_imm, + layout.op_alu_reg, + layout.op_misc_mem, + layout.op_system, + layout.op_amo, + ]); + out.extend_from_slice(&layout.funct3_is); + out.extend_from_slice(&[layout.imm_i, layout.imm_s, layout.imm_b, layout.imm_j]); + out.extend_from_slice(&layout.rd_bit); + out.extend_from_slice(&layout.funct3_bit); + out.extend_from_slice(&layout.rs1_bit); + out.extend_from_slice(&layout.rs2_bit); + out.extend_from_slice(&layout.funct7_bit); + out.push(layout.rd_is_zero); + out +} + +#[inline] +pub const fn rv32_decode_lookup_table_id_for_col(col: usize) -> u32 { + RV32_TRACE_DECODE_LOOKUP_TABLE_BASE + col as u32 +} + +#[inline] +pub const fn rv32_is_decode_lookup_table_id(table_id: u32) -> bool { + table_id >= RV32_TRACE_DECODE_LOOKUP_TABLE_BASE && table_id < 
RV32_TRACE_DECODE_LOOKUP_TABLE_BASE + 77 +} + +#[inline] +pub fn rv32_decode_lookup_addr_group_for_table_id(table_id: u32) -> Option { + if !rv32_is_decode_lookup_table_id(table_id) { + return None; + } + let col_id = (table_id - RV32_TRACE_DECODE_LOOKUP_TABLE_BASE) as usize; + let layout = Rv32DecodeSidecarLayout::new(); + let backed_cols = rv32_decode_lookup_backed_cols(&layout); + backed_cols + .iter() + .any(|&c| c == col_id) + .then_some(RV32_TRACE_DECODE_ADDR_GROUP_BASE) +} + +#[inline] +fn sign_extend_to_u32(value: u32, bits: u32) -> u32 { + debug_assert!(bits > 0 && bits <= 32); + let shift = 32 - bits; + (((value << shift) as i32) >> shift) as u32 +} + +#[inline] +fn imm_i_from_word(instr_word: u32) -> u32 { + sign_extend_to_u32((instr_word >> 20) & 0x0fff, 12) +} + +#[inline] +fn imm_s_from_word(instr_word: u32) -> u32 { + let imm = ((instr_word >> 7) & 0x1f) | (((instr_word >> 25) & 0x7f) << 5); + sign_extend_to_u32(imm, 12) +} + +#[inline] +fn imm_b_from_word(instr_word: u32) -> u32 { + let imm = (((instr_word >> 31) & 0x1) << 12) + | (((instr_word >> 7) & 0x1) << 11) + | (((instr_word >> 25) & 0x3f) << 5) + | (((instr_word >> 8) & 0xf) << 1); + sign_extend_to_u32(imm, 13) +} + +#[inline] +fn imm_j_from_word(instr_word: u32) -> u32 { + let imm = (((instr_word >> 31) & 0x1) << 20) + | (((instr_word >> 12) & 0xff) << 12) + | (((instr_word >> 20) & 0x1) << 11) + | (((instr_word >> 21) & 0x3ff) << 1); + sign_extend_to_u32(imm, 21) +} + +#[inline] +fn opcode_writes_rd(opcode_u64: u64) -> bool { + matches!(opcode_u64, 0x37 | 0x17 | 0x6F | 0x67 | 0x03 | 0x13 | 0x33) +} + +pub fn rv32_decode_lookup_backed_row_from_instr_word( + layout: &Rv32DecodeSidecarLayout, + instr_word: u32, + active: bool, +) -> Vec { + let mut row = vec![F::ZERO; layout.cols]; + let opcode_u64 = (instr_word & 0x7f) as u64; + let funct3_u64 = ((instr_word >> 12) & 0x7) as u64; + let funct7_u64 = ((instr_word >> 25) & 0x7f) as u64; + let rd_u64 = ((instr_word >> 7) & 0x1f) as u64; + let 
rs1_u64 = ((instr_word >> 15) & 0x1f) as u64; + let rs2_u64 = ((instr_word >> 20) & 0x1f) as u64; + + row[layout.opcode] = F::from_u64(opcode_u64); + row[layout.funct3] = F::from_u64(funct3_u64); + row[layout.funct7] = F::from_u64(funct7_u64); + row[layout.rd] = F::from_u64(rd_u64); + row[layout.rs1] = F::from_u64(rs1_u64); + row[layout.rs2] = F::from_u64(rs2_u64); + row[layout.imm_i] = F::from_u64(imm_i_from_word(instr_word) as u64); + row[layout.imm_s] = F::from_u64(imm_s_from_word(instr_word) as u64); + row[layout.imm_b] = F::from_u64(imm_b_from_word(instr_word) as u64); + row[layout.imm_j] = F::from_u64(imm_j_from_word(instr_word) as u64); + for (k, &bit_col) in layout.rd_bit.iter().enumerate() { + row[bit_col] = F::from_u64((rd_u64 >> k) & 1); + } + for (k, &bit_col) in layout.funct3_bit.iter().enumerate() { + row[bit_col] = F::from_u64((funct3_u64 >> k) & 1); + } + for (k, &bit_col) in layout.rs1_bit.iter().enumerate() { + row[bit_col] = F::from_u64((rs1_u64 >> k) & 1); + } + for (k, &bit_col) in layout.rs2_bit.iter().enumerate() { + row[bit_col] = F::from_u64((rs2_u64 >> k) & 1); + } + for (k, &bit_col) in layout.funct7_bit.iter().enumerate() { + row[bit_col] = F::from_u64((funct7_u64 >> k) & 1); + } + let one_minus_b0 = F::ONE - row[layout.rd_bit[0]]; + let one_minus_b1 = F::ONE - row[layout.rd_bit[1]]; + let one_minus_b2 = F::ONE - row[layout.rd_bit[2]]; + let one_minus_b3 = F::ONE - row[layout.rd_bit[3]]; + let one_minus_b4 = F::ONE - row[layout.rd_bit[4]]; + row[layout.rd_is_zero_01] = one_minus_b0 * one_minus_b1; + row[layout.rd_is_zero_012] = row[layout.rd_is_zero_01] * one_minus_b2; + row[layout.rd_is_zero_0123] = row[layout.rd_is_zero_012] * one_minus_b3; + row[layout.rd_is_zero] = row[layout.rd_is_zero_0123] * one_minus_b4; + + let is = |op: u64| if opcode_u64 == op { F::ONE } else { F::ZERO }; + row[layout.op_lui] = is(0x37); + row[layout.op_auipc] = is(0x17); + row[layout.op_jal] = is(0x6F); + row[layout.op_jalr] = is(0x67); + 
row[layout.op_branch] = is(0x63); + row[layout.op_load] = is(0x03); + row[layout.op_store] = is(0x23); + row[layout.op_alu_imm] = is(0x13); + row[layout.op_alu_reg] = is(0x33); + row[layout.op_misc_mem] = is(0x0F); + row[layout.op_system] = is(0x73); + row[layout.op_amo] = is(0x2F); + + let rd_has_write_f = if opcode_writes_rd(opcode_u64) && rd_u64 != 0 { + F::ONE + } else { + F::ZERO + }; + row[layout.rd_has_write] = rd_has_write_f; + row[layout.op_lui_write] = row[layout.op_lui] * rd_has_write_f; + row[layout.op_auipc_write] = row[layout.op_auipc] * rd_has_write_f; + row[layout.op_jal_write] = row[layout.op_jal] * rd_has_write_f; + row[layout.op_jalr_write] = row[layout.op_jalr] * rd_has_write_f; + row[layout.op_alu_imm_write] = row[layout.op_alu_imm] * rd_has_write_f; + row[layout.op_alu_reg_write] = row[layout.op_alu_reg] * rd_has_write_f; + + let is_load = opcode_u64 == 0x03; + let is_lb = is_load && funct3_u64 == 0b000; + let is_lh = is_load && funct3_u64 == 0b001; + let is_lw = is_load && funct3_u64 == 0b010; + let is_lbu = is_load && funct3_u64 == 0b100; + let is_lhu = is_load && funct3_u64 == 0b101; + let flag = |on: bool| if on { F::ONE } else { F::ZERO }; + row[layout.is_lb_write] = flag(is_lb) * rd_has_write_f; + row[layout.is_lbu_write] = flag(is_lbu) * rd_has_write_f; + row[layout.is_lh_write] = flag(is_lh) * rd_has_write_f; + row[layout.is_lhu_write] = flag(is_lhu) * rd_has_write_f; + row[layout.is_lw_write] = flag(is_lw) * rd_has_write_f; + let is_store = opcode_u64 == 0x23; + let is_sb = is_store && funct3_u64 == 0b000; + let is_sh = is_store && funct3_u64 == 0b001; + row[layout.ram_has_read] = if is_load || is_sb || is_sh { F::ONE } else { F::ZERO }; + row[layout.ram_has_write] = if is_store { F::ONE } else { F::ZERO }; + + for (k, &f3_col) in layout.funct3_is.iter().enumerate() { + row[f3_col] = if active && funct3_u64 == k as u64 { + F::ONE + } else { + F::ZERO + }; + } + + let funct7_b5 = (funct7_u64 >> 5) & 1; + let f3_is_0 = if active && 
funct3_u64 == 0 { 1 } else { 0 }; + let f3_is_5 = if active && funct3_u64 == 5 { 1 } else { 0 }; + let alu_table_base: u64 = match funct3_u64 { + 0 => 3, + 1 => 7, + 2 => 5, + 3 => 6, + 4 => 1, + 5 => 8, + 6 => 2, + _ => 0, + }; + let branch_table_expected: u64 = + 10 - 5 * ((funct3_u64 >> 2) & 1) + (((funct3_u64 >> 1) & 1) * ((funct3_u64 >> 2) & 1)); + row[layout.shout_table_id] = if opcode_u64 == 0x33 { + F::from_u64(alu_table_base + (funct7_b5 * (f3_is_0 + f3_is_5))) + } else if opcode_u64 == 0x13 { + F::from_u64(alu_table_base + (funct7_b5 * f3_is_5)) + } else if opcode_u64 == 0x63 { + F::from_u64(branch_table_expected) + } else if matches!(opcode_u64, 0x03 | 0x23 | 0x67 | 0x17) { + // LOAD/STORE/JALR/AUIPC use ADD shout semantics in the current trace runner. + F::from_u64(3) + } else { + F::ZERO + }; + row[layout.alu_reg_table_delta] = F::from_u64(funct7_b5 * (f3_is_0 + f3_is_5)); + row[layout.alu_imm_table_delta] = F::from_u64(funct7_b5 * f3_is_5); + + let shift_f3_sel = row[layout.funct3_is[1]] + row[layout.funct3_is[5]]; + row[layout.alu_imm_shift_rhs_delta] = shift_f3_sel * (F::from_u64(rs2_u64) - row[layout.imm_i]); + + row +} diff --git a/crates/neo-memory/src/riscv/trace/layout.rs b/crates/neo-memory/src/riscv/trace/layout.rs new file mode 100644 index 00000000..c3f52ca4 --- /dev/null +++ b/crates/neo-memory/src/riscv/trace/layout.rs @@ -0,0 +1,96 @@ +#[derive(Clone, Debug)] +pub struct Rv32TraceLayout { + pub cols: usize, + + // Core control / fetch. + pub one: usize, + pub active: usize, + pub halted: usize, + pub cycle: usize, + pub pc_before: usize, + pub pc_after: usize, + pub instr_word: usize, + + // Regfile view (REG Twist). + pub rs1_addr: usize, + pub rs1_val: usize, + pub rs2_addr: usize, + pub rs2_val: usize, + pub rd_addr: usize, + pub rd_val: usize, + + // RAM view (RAM Twist, normalized to at most 1R + 1W per row). 
+ pub ram_addr: usize, + pub ram_rv: usize, + pub ram_wv: usize, + + // Shout view (single fixed-lane per row; output-only for now). + pub shout_has_lookup: usize, + pub shout_val: usize, + pub shout_lhs: usize, + pub shout_rhs: usize, + pub jalr_drop_bit: usize, +} + +impl Rv32TraceLayout { + pub fn new() -> Self { + let mut next = 0usize; + let mut take = || { + let out = next; + next += 1; + out + }; + + let one = take(); + let active = take(); + let halted = take(); + let cycle = take(); + let pc_before = take(); + let pc_after = take(); + let instr_word = take(); + + let rs1_addr = take(); + let rs1_val = take(); + let rs2_addr = take(); + let rs2_val = take(); + let rd_addr = take(); + let rd_val = take(); + + let ram_addr = take(); + let ram_rv = take(); + let ram_wv = take(); + + let shout_has_lookup = take(); + let shout_val = take(); + let shout_lhs = take(); + let shout_rhs = take(); + let jalr_drop_bit = take(); + + debug_assert_eq!(next, 21, "RV32 trace width drift after decode-helper offload"); + + Self { + cols: next, + one, + active, + halted, + cycle, + pc_before, + pc_after, + instr_word, + rs1_addr, + rs1_val, + rs2_addr, + rs2_val, + rd_addr, + rd_val, + ram_addr, + ram_rv, + ram_wv, + shout_has_lookup, + shout_val, + shout_lhs, + shout_rhs, + jalr_drop_bit, + } + } +} diff --git a/crates/neo-memory/src/riscv/trace/mod.rs b/crates/neo-memory/src/riscv/trace/mod.rs new file mode 100644 index 00000000..45c00e55 --- /dev/null +++ b/crates/neo-memory/src/riscv/trace/mod.rs @@ -0,0 +1,55 @@ +pub mod air; +pub mod decode_lookup; +pub mod layout; +pub mod sidecar_extract; +pub mod width_sidecar; +pub mod witness; + +pub use air::Rv32TraceAir; +pub use decode_lookup::{ + rv32_decode_lookup_addr_group_for_table_id, rv32_decode_lookup_backed_cols, + rv32_decode_lookup_backed_row_from_instr_word, rv32_decode_lookup_table_id_for_col, rv32_is_decode_lookup_table_id, + Rv32DecodeSidecarLayout, RV32_TRACE_DECODE_LOOKUP_TABLE_BASE, +}; +pub use 
layout::Rv32TraceLayout; +pub use sidecar_extract::{ + extract_shout_lanes_over_time, extract_twist_lanes_over_time, ShoutLaneOverTime, TraceTwistLanesOverTime, + TwistLaneOverTime, +}; +pub use width_sidecar::{ + rv32_is_width_lookup_table_id, rv32_width_lookup_addr_group_for_table_id, rv32_width_lookup_backed_cols, + rv32_width_lookup_table_id_for_col, rv32_width_sidecar_witness_from_exec_table, Rv32WidthSidecarLayout, + Rv32WidthSidecarWitness, RV32_TRACE_WIDTH_LOOKUP_TABLE_BASE, +}; +pub use witness::Rv32TraceWitness; + +/// Shared-address group id for canonical RV32 opcode Shout tables (table_id 0..=19). +/// +/// These families all use the same interleaved `(lhs,rhs)` key width (`ell_addr=64`), +/// so in RV32 trace shared-bus mode they can share one addr-bit range. +pub const RV32_TRACE_OPCODE_ADDR_GROUP: u32 = 0x5256_4100; +/// Shared selector-group id for decode lookup families (table_id range at `RV32_TRACE_DECODE_LOOKUP_TABLE_BASE`). +pub const RV32_TRACE_DECODE_SELECTOR_GROUP: u32 = 0x5256_4B00; +/// Shared selector-group id for width lookup families (table_id range at `RV32_TRACE_WIDTH_LOOKUP_TABLE_BASE`). 
+pub const RV32_TRACE_WIDTH_SELECTOR_GROUP: u32 = 0x5256_5B00; + +#[inline] +pub fn rv32_trace_lookup_addr_group_for_table_id(table_id: u32) -> Option { + if table_id <= 19 { + Some(RV32_TRACE_OPCODE_ADDR_GROUP) + } else { + rv32_decode_lookup_addr_group_for_table_id(table_id) + .or_else(|| rv32_width_lookup_addr_group_for_table_id(table_id)) + } +} + +#[inline] +pub fn rv32_trace_lookup_selector_group_for_table_id(table_id: u32) -> Option { + if rv32_is_decode_lookup_table_id(table_id) { + Some(RV32_TRACE_DECODE_SELECTOR_GROUP) + } else if rv32_is_width_lookup_table_id(table_id) { + Some(RV32_TRACE_WIDTH_SELECTOR_GROUP) + } else { + None + } +} diff --git a/crates/neo-memory/src/riscv/trace/sidecar_extract.rs b/crates/neo-memory/src/riscv/trace/sidecar_extract.rs new file mode 100644 index 00000000..be1a3bcb --- /dev/null +++ b/crates/neo-memory/src/riscv/trace/sidecar_extract.rs @@ -0,0 +1,328 @@ +use std::collections::HashMap; + +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +use crate::riscv::exec_table::Rv32ExecTable; +use crate::riscv::lookups::{interleave_bits, uninterleave_bits, RiscvOpcode, RiscvShoutTables}; + +#[derive(Clone, Debug)] +pub struct TwistLaneOverTime { + pub has_read: Vec, + pub ra: Vec, + pub rv: Vec, + pub has_write: Vec, + pub wa: Vec, + pub wv: Vec, + pub inc_at_write_addr: Vec, +} + +impl TwistLaneOverTime { + fn new_zero(t: usize) -> Self { + Self { + has_read: vec![false; t], + ra: vec![0; t], + rv: vec![0; t], + has_write: vec![false; t], + wa: vec![0; t], + wv: vec![0; t], + inc_at_write_addr: vec![F::ZERO; t], + } + } +} + +#[derive(Clone, Debug)] +pub struct TraceTwistLanesOverTime { + pub prog: TwistLaneOverTime, + pub reg_lane0: TwistLaneOverTime, + pub reg_lane1: TwistLaneOverTime, + pub ram: TwistLaneOverTime, +} + +#[derive(Clone, Debug)] +pub struct ShoutLaneOverTime { + pub has_lookup: Vec, + pub key: Vec, + pub value: Vec, +} + +impl ShoutLaneOverTime { + fn new_zero(t: usize) -> Self { + 
Self { + has_lookup: vec![false; t], + key: vec![0; t], + value: vec![0; t], + } + } +} + +/// Extract fixed-lane Twist-style memories over time from `Rv32ExecTable`. +/// +/// Layout/policy: +/// - PROG: lane0 read-only (exactly one read per active row) +/// - REG: lane0 reads rs1 + optional write rd, lane1 reads rs2 (exactly one read per active row) +/// - RAM: lane0 supports at most 1 read + 1 write per active row, both must share the same addr +/// +/// `init_regs` and `init_ram` are used to compute `inc_at_write_addr` and to sanity-check read values. +pub fn extract_twist_lanes_over_time( + exec: &Rv32ExecTable, + init_regs: &HashMap, + init_ram: &HashMap, + ram_ell_addr: usize, +) -> Result { + let t = exec.rows.len(); + + // Build REG state for `inc_at_write_addr`. + let mut regs = [0u64; 32]; + for (&addr, &value) in init_regs { + if addr >= 32 { + return Err(format!("trace extract: reg init addr out of range: addr={addr}")); + } + if addr == 0 && value != 0 { + return Err("trace extract: reg init must keep x0 == 0".into()); + } + regs[addr as usize] = value; + } + + // Build RAM state for `inc_at_write_addr` and read-value checks. 
+ if ram_ell_addr > 64 { + return Err(format!( + "trace extract: RAM ell_addr too large for u64 addressing: ell_addr={ram_ell_addr}" + )); + } + let mut ram: HashMap = HashMap::new(); + for (&addr, &value) in init_ram { + if ram_ell_addr < 64 && (addr >> ram_ell_addr) != 0 { + return Err(format!( + "trace extract: RAM init addr out of range for ell_addr={ram_ell_addr}: addr={addr}" + )); + } + if value != 0 { + ram.insert(addr, value); + } + } + + let mut prog = TwistLaneOverTime::new_zero(t); + let mut reg0 = TwistLaneOverTime::new_zero(t); + let mut reg1 = TwistLaneOverTime::new_zero(t); + let mut ram_lane = TwistLaneOverTime::new_zero(t); + + for (row_idx, r) in exec.rows.iter().enumerate() { + if !r.active { + if r.prog_read.is_some() + || r.reg_read_lane0.is_some() + || r.reg_read_lane1.is_some() + || r.reg_write_lane0.is_some() + || !r.ram_events.is_empty() + || !r.shout_events.is_empty() + { + return Err(format!("trace extract: inactive row has events at cycle {}", r.cycle)); + } + continue; + } + + // PROG: exactly one read + let prog_read = r + .prog_read + .as_ref() + .ok_or_else(|| format!("trace extract: active row missing prog_read at cycle {}", r.cycle))?; + prog.has_read[row_idx] = true; + prog.ra[row_idx] = prog_read.addr; + prog.rv[row_idx] = prog_read.value; + + // REG: exactly one read per lane + let rs1 = r + .reg_read_lane0 + .as_ref() + .ok_or_else(|| format!("trace extract: missing REG lane0 read at cycle {}", r.cycle))?; + let rs2 = r + .reg_read_lane1 + .as_ref() + .ok_or_else(|| format!("trace extract: missing REG lane1 read at cycle {}", r.cycle))?; + + reg0.has_read[row_idx] = true; + reg0.ra[row_idx] = rs1.addr; + reg0.rv[row_idx] = rs1.value; + + reg1.has_read[row_idx] = true; + reg1.ra[row_idx] = rs2.addr; + reg1.rv[row_idx] = rs2.value; + + if let Some(wr) = &r.reg_write_lane0 { + if wr.addr == 0 { + return Err(format!("trace extract: unexpected x0 write at cycle {}", r.cycle)); + } + if wr.addr >= 32 { + return Err(format!( + "trace 
extract: reg write addr out of range at cycle {}: addr={}", + r.cycle, wr.addr + )); + } + let prev = regs[wr.addr as usize]; + regs[wr.addr as usize] = wr.value; + regs[0] = 0; + + reg0.has_write[row_idx] = true; + reg0.wa[row_idx] = wr.addr; + reg0.wv[row_idx] = wr.value; + reg0.inc_at_write_addr[row_idx] = F::from_u64(wr.value) - F::from_u64(prev); + } + + // RAM (fixed-lane MVP): at most 1 read + 1 write per row + let mut read: Option<(u64, u64)> = None; + let mut write: Option<(u64, u64)> = None; + for e in &r.ram_events { + if e.lane.is_some() { + return Err(format!( + "trace extract: unexpected RAM lane hint at cycle {}: lane={:?}", + r.cycle, e.lane + )); + } + match e.kind { + neo_vm_trace::TwistOpKind::Read => { + if read.is_some() { + return Err(format!("trace extract: multiple RAM reads at cycle {}", r.cycle)); + } + read = Some((e.addr, e.value)); + } + neo_vm_trace::TwistOpKind::Write => { + if write.is_some() { + return Err(format!("trace extract: multiple RAM writes at cycle {}", r.cycle)); + } + write = Some((e.addr, e.value)); + } + } + } + + let has_read = read.is_some(); + let has_write = write.is_some(); + ram_lane.has_read[row_idx] = has_read; + ram_lane.has_write[row_idx] = has_write; + + let (ra, rv) = match read { + Some((a, v)) => (a, Some(v)), + None => (0, None), + }; + let (wa, wv) = match write { + Some((a, v)) => (a, Some(v)), + None => (0, None), + }; + if has_read && has_write && ra != wa { + return Err(format!( + "trace extract: RAM read/write addr mismatch at cycle {}: ra={ra} wa={wa}", + r.cycle + )); + } + let addr = if has_read { ra } else { wa }; + + if ram_ell_addr < 64 && (addr >> ram_ell_addr) != 0 { + return Err(format!( + "trace extract: RAM addr out of range for ell_addr={ram_ell_addr} at cycle {}: addr={addr}", + r.cycle + )); + } + + if let Some(v) = rv { + let prev = ram.get(&addr).copied().unwrap_or(0); + if prev != v { + return Err(format!( + "trace extract: RAM read value mismatch at cycle {} addr={addr}: got={v} 
expected_prev={prev}", + r.cycle + )); + } + ram_lane.ra[row_idx] = addr; + ram_lane.rv[row_idx] = v; + } + + if let Some(v) = wv { + let prev = ram.get(&addr).copied().unwrap_or(0); + ram_lane.wa[row_idx] = addr; + ram_lane.wv[row_idx] = v; + ram_lane.inc_at_write_addr[row_idx] = F::from_u64(v) - F::from_u64(prev); + + if v == 0 { + ram.remove(&addr); + } else { + ram.insert(addr, v); + } + } + } + + Ok(TraceTwistLanesOverTime { + prog, + reg_lane0: reg0, + reg_lane1: reg1, + ram: ram_lane, + }) +} + +/// Extract fixed-lane Shout lanes over time (one lane per `shout_table_ids` entry). +/// +/// Policy: +/// - At most 1 Shout event per active row. +/// - Inactive rows must have no shout events. +pub fn extract_shout_lanes_over_time( + exec: &Rv32ExecTable, + shout_table_ids: &[u32], +) -> Result, String> { + let t = exec.rows.len(); + + let mut table_id_to_idx: HashMap = HashMap::new(); + for (idx, &id) in shout_table_ids.iter().enumerate() { + if table_id_to_idx.insert(id, idx).is_some() { + return Err(format!("trace extract: duplicate shout_table_id={id}")); + } + } + + let mut lanes: Vec = (0..shout_table_ids.len()) + .map(|_| ShoutLaneOverTime::new_zero(t)) + .collect(); + + for (row_idx, r) in exec.rows.iter().enumerate() { + if !r.active { + if !r.shout_events.is_empty() { + return Err(format!( + "trace extract: inactive row has Shout events at cycle {}", + r.cycle + )); + } + continue; + } + + match r.shout_events.as_slice() { + [] => {} + [ev] => { + let idx = table_id_to_idx + .get(&ev.shout_id.0) + .copied() + .ok_or_else(|| { + format!( + "trace extract: shout_id={} not provisioned (cycle {})", + ev.shout_id.0, r.cycle + ) + })?; + lanes[idx].has_lookup[row_idx] = true; + let mut key = ev.key; + if let Some(op) = RiscvShoutTables::new(/*xlen=*/ 32).id_to_opcode(ev.shout_id) { + // Canonicalize shift keys: RISC-V shifts use only the low 5 bits of `rhs`. 
+ // This shrinks the key space and keeps trace/sidecar linkage stable across packed / bit-addressed encodings. + if matches!(op, RiscvOpcode::Sll | RiscvOpcode::Srl | RiscvOpcode::Sra) { + let (lhs, rhs) = uninterleave_bits(key as u128); + let rhs_masked = rhs & 0x1F; + key = interleave_bits(lhs, rhs_masked) as u64; + } + } + lanes[idx].key[row_idx] = key; + lanes[idx].value[row_idx] = ev.value as u64; + } + _ => { + return Err(format!( + "trace extract: multiple Shout events at cycle {} (fixed-lane policy supports 1)", + r.cycle + )); + } + } + } + + Ok(lanes) +} diff --git a/crates/neo-memory/src/riscv/trace/width_sidecar.rs b/crates/neo-memory/src/riscv/trace/width_sidecar.rs new file mode 100644 index 00000000..b60bf743 --- /dev/null +++ b/crates/neo-memory/src/riscv/trace/width_sidecar.rs @@ -0,0 +1,177 @@ +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +use crate::riscv::exec_table::Rv32ExecTable; + +/// Base lookup table id for width-column lookup families in shared-bus mode. +/// +/// Table id for width column `c` is `RV32_TRACE_WIDTH_LOOKUP_TABLE_BASE + c`. +pub const RV32_TRACE_WIDTH_LOOKUP_TABLE_BASE: u32 = 0x5256_5800; +/// Base address-group id for width lookup lanes. 
+pub const RV32_TRACE_WIDTH_ADDR_GROUP_BASE: u32 = 0x5256_5A00; + +#[derive(Clone, Debug)] +pub struct Rv32WidthSidecarLayout { + pub cols: usize, + pub ram_rv_q16: usize, + pub rs2_q16: usize, + pub ram_rv_low_bit: [usize; 16], + pub rs2_low_bit: [usize; 16], +} + +impl Rv32WidthSidecarLayout { + pub fn new() -> Self { + let mut next = 0usize; + let mut take = || { + let out = next; + next += 1; + out + }; + + let ram_rv_q16 = take(); + let rs2_q16 = take(); + + let ram_rv_b0 = take(); + let ram_rv_b1 = take(); + let ram_rv_b2 = take(); + let ram_rv_b3 = take(); + let ram_rv_b4 = take(); + let ram_rv_b5 = take(); + let ram_rv_b6 = take(); + let ram_rv_b7 = take(); + let ram_rv_b8 = take(); + let ram_rv_b9 = take(); + let ram_rv_b10 = take(); + let ram_rv_b11 = take(); + let ram_rv_b12 = take(); + let ram_rv_b13 = take(); + let ram_rv_b14 = take(); + let ram_rv_b15 = take(); + + let rs2_low_b0 = take(); + let rs2_low_b1 = take(); + let rs2_low_b2 = take(); + let rs2_low_b3 = take(); + let rs2_low_b4 = take(); + let rs2_low_b5 = take(); + let rs2_low_b6 = take(); + let rs2_low_b7 = take(); + let rs2_low_b8 = take(); + let rs2_low_b9 = take(); + let rs2_low_b10 = take(); + let rs2_low_b11 = take(); + let rs2_low_b12 = take(); + let rs2_low_b13 = take(); + let rs2_low_b14 = take(); + let rs2_low_b15 = take(); + + debug_assert_eq!(next, 34); + Self { + cols: next, + ram_rv_q16, + rs2_q16, + ram_rv_low_bit: [ + ram_rv_b0, ram_rv_b1, ram_rv_b2, ram_rv_b3, ram_rv_b4, ram_rv_b5, ram_rv_b6, ram_rv_b7, ram_rv_b8, + ram_rv_b9, ram_rv_b10, ram_rv_b11, ram_rv_b12, ram_rv_b13, ram_rv_b14, ram_rv_b15, + ], + rs2_low_bit: [ + rs2_low_b0, + rs2_low_b1, + rs2_low_b2, + rs2_low_b3, + rs2_low_b4, + rs2_low_b5, + rs2_low_b6, + rs2_low_b7, + rs2_low_b8, + rs2_low_b9, + rs2_low_b10, + rs2_low_b11, + rs2_low_b12, + rs2_low_b13, + rs2_low_b14, + rs2_low_b15, + ], + } + } +} + +#[inline] +pub fn rv32_width_lookup_backed_cols(layout: &Rv32WidthSidecarLayout) -> Vec { + 
(0..layout.cols).collect() +} + +#[inline] +pub const fn rv32_width_lookup_table_id_for_col(col: usize) -> u32 { + RV32_TRACE_WIDTH_LOOKUP_TABLE_BASE + col as u32 +} + +#[inline] +pub const fn rv32_is_width_lookup_table_id(table_id: u32) -> bool { + table_id >= RV32_TRACE_WIDTH_LOOKUP_TABLE_BASE && table_id < RV32_TRACE_WIDTH_LOOKUP_TABLE_BASE + 34 +} + +#[inline] +pub fn rv32_width_lookup_addr_group_for_table_id(table_id: u32) -> Option { + if !rv32_is_width_lookup_table_id(table_id) { + return None; + } + Some(RV32_TRACE_WIDTH_ADDR_GROUP_BASE) +} + +#[derive(Clone, Debug)] +pub struct Rv32WidthSidecarWitness { + pub t: usize, + pub cols: Vec>, +} + +impl Rv32WidthSidecarWitness { + pub fn new_zero(layout: &Rv32WidthSidecarLayout, t: usize) -> Self { + Self { + t, + cols: vec![vec![F::ZERO; t]; layout.cols], + } + } +} + +pub fn rv32_width_sidecar_witness_from_exec_table( + layout: &Rv32WidthSidecarLayout, + exec: &Rv32ExecTable, +) -> Rv32WidthSidecarWitness { + let cols = exec.to_columns(); + let t = cols.len(); + let mut wit = Rv32WidthSidecarWitness::new_zero(layout, t); + + for i in 0..t { + if !cols.active[i] { + continue; + } + + let rs2_val_u64 = cols.rs2_val[i]; + wit.cols[layout.rs2_q16][i] = F::from_u64(rs2_val_u64 >> 16); + for (k, &bit_col) in layout.rs2_low_bit.iter().enumerate() { + wit.cols[bit_col][i] = F::from_u64((rs2_val_u64 >> k) & 1); + } + } + + for (i, r) in exec.rows.iter().enumerate() { + if !r.active { + continue; + } + let mut read_value: Option = None; + for e in &r.ram_events { + if e.kind == neo_vm_trace::TwistOpKind::Read { + read_value = Some(e.value); + break; + } + } + if let Some(rv) = read_value { + wit.cols[layout.ram_rv_q16][i] = F::from_u64(rv >> 16); + for (k, &bit_col) in layout.ram_rv_low_bit.iter().enumerate() { + wit.cols[bit_col][i] = F::from_u64((rv >> k) & 1); + } + } + } + + wit +} diff --git a/crates/neo-memory/src/riscv/trace/witness.rs b/crates/neo-memory/src/riscv/trace/witness.rs new file mode 100644 index 
00000000..4d3ac949 --- /dev/null +++ b/crates/neo-memory/src/riscv/trace/witness.rs @@ -0,0 +1,158 @@ +use neo_vm_trace::TwistOpKind; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +use crate::riscv::exec_table::Rv32ExecTable; +use crate::riscv::lookups::{uninterleave_bits, RiscvOpcode, RiscvShoutTables}; + +use super::layout::Rv32TraceLayout; + +#[inline] +fn sign_extend_to_u32(value: u32, bits: u32) -> u32 { + let shift = 32 - bits; + (((value << shift) as i32) >> shift) as u32 +} + +#[inline] +fn imm_i_from_word(instr_word: u32) -> u32 { + sign_extend_to_u32((instr_word >> 20) & 0x0fff, 12) +} + +#[derive(Clone, Debug)] +pub struct Rv32TraceWitness { + pub t: usize, + /// Column-major: `cols[col][row]`. + pub cols: Vec>, +} + +impl Rv32TraceWitness { + pub fn new_zero(layout: &Rv32TraceLayout, t: usize) -> Self { + Self { + t, + cols: vec![vec![F::ZERO; t]; layout.cols], + } + } + + pub fn from_exec_table(layout: &Rv32TraceLayout, exec: &Rv32ExecTable) -> Result { + let cols = exec.to_columns(); + let t = cols.len(); + let mut wit = Self::new_zero(layout, t); + + for i in 0..t { + wit.cols[layout.one][i] = F::ONE; + + // Control / fetch + wit.cols[layout.active][i] = if cols.active[i] { F::ONE } else { F::ZERO }; + wit.cols[layout.halted][i] = if cols.halted[i] { F::ONE } else { F::ZERO }; + wit.cols[layout.cycle][i] = F::from_u64(cols.cycle[i]); + wit.cols[layout.pc_before][i] = F::from_u64(cols.pc_before[i]); + wit.cols[layout.pc_after][i] = F::from_u64(cols.pc_after[i]); + wit.cols[layout.instr_word][i] = F::from_u64(cols.instr_word[i] as u64); + if !cols.active[i] { + // Inactive rows stay quiescent; WB/WP sidecars enforce these zeros. 
+ continue; + } + + // REG view + wit.cols[layout.rs1_addr][i] = F::from_u64(cols.rs1_addr[i]); + wit.cols[layout.rs1_val][i] = F::from_u64(cols.rs1_val[i]); + wit.cols[layout.rs2_addr][i] = F::from_u64(cols.rs2_addr[i]); + wit.cols[layout.rs2_val][i] = F::from_u64(cols.rs2_val[i]); + // Keep rd_addr aligned with decoded instruction field. + // REG write enable is carried by decode lookup selectors, so on non-write rows + // this address is don't-care for bus semantics. + wit.cols[layout.rd_addr][i] = F::from_u64(cols.rd[i] as u64); + wit.cols[layout.rd_val][i] = F::from_u64(cols.rd_val[i]); + if cols.opcode[i] == 0x67 { + let rs1 = cols.rs1_val[i] as u32; + let imm_i = imm_i_from_word(cols.instr_word[i]); + let drop = rs1.wrapping_add(imm_i) & 1; + wit.cols[layout.jalr_drop_bit][i] = F::from_u64(drop as u64); + } + } + + // Normalize RAM events per row: at most one read + one write. + for (i, r) in exec.rows.iter().enumerate() { + if !r.active { + continue; + } + + let mut read: Option<(u64, u64)> = None; + let mut write: Option<(u64, u64)> = None; + for e in &r.ram_events { + match e.kind { + TwistOpKind::Read => { + if read.is_some() { + return Err(format!("multiple RAM reads in one cycle={}", r.cycle)); + } + read = Some((e.addr, e.value)); + } + TwistOpKind::Write => { + if write.is_some() { + return Err(format!("multiple RAM writes in one cycle={}", r.cycle)); + } + write = Some((e.addr, e.value)); + } + } + } + + match (read, write) { + (Some((ra, rv)), Some((wa, wv))) => { + if ra != wa { + return Err(format!( + "RAM read/write addr mismatch in one cycle {}: ra={:#x} wa={:#x}", + r.cycle, ra, wa + )); + } + wit.cols[layout.ram_addr][i] = F::from_u64(ra); + wit.cols[layout.ram_rv][i] = F::from_u64(rv); + wit.cols[layout.ram_wv][i] = F::from_u64(wv); + } + (Some((ra, rv)), None) => { + wit.cols[layout.ram_addr][i] = F::from_u64(ra); + wit.cols[layout.ram_rv][i] = F::from_u64(rv); + } + (None, Some((wa, wv))) => { + wit.cols[layout.ram_addr][i] = 
F::from_u64(wa); + wit.cols[layout.ram_wv][i] = F::from_u64(wv); + } + (None, None) => {} + } + } + + // Normalize fixed-lane Shout view for the main trace. + // + // Shared-bus mode may carry auxiliary lookup families in addition to + // opcode-backed Shout events. The fixed-lane CPU shout glue must only + // bind to canonical RV32 opcode tables. + let shout_tables = RiscvShoutTables::new(/*xlen=*/ 32); + for (i, r) in exec.rows.iter().enumerate() { + if !r.active { + continue; + } + let primary = r + .shout_events + .iter() + .find(|ev| shout_tables.id_to_opcode(ev.shout_id).is_some()); + + if let Some(ev) = primary { + wit.cols[layout.shout_has_lookup][i] = F::ONE; + wit.cols[layout.shout_val][i] = F::from_u64(ev.value); + let (lhs, rhs) = uninterleave_bits(ev.key as u128); + wit.cols[layout.shout_lhs][i] = F::from_u64(lhs); + // Canonicalize shift keys: RISC-V shifts use only the low 5 bits of `rhs`. + let rhs = if let Some(op) = shout_tables.id_to_opcode(ev.shout_id) { + if matches!(op, RiscvOpcode::Sll | RiscvOpcode::Srl | RiscvOpcode::Sra) { + rhs & 0x1F + } else { + rhs + } + } else { + rhs + }; + wit.cols[layout.shout_rhs][i] = F::from_u64(rhs); + } + } + Ok(wit) + } +} diff --git a/crates/neo-memory/src/shout.rs b/crates/neo-memory/src/shout.rs index 77c20107..527a12f1 100644 --- a/crates/neo-memory/src/shout.rs +++ b/crates/neo-memory/src/shout.rs @@ -153,6 +153,16 @@ pub fn check_shout_semantics( let out = compute_op(*opcode, rs1, rs2, *xlen); F::from_u64(out) } + Some(LutTableSpec::RiscvOpcodePacked { .. }) => { + return Err(PiCcsError::InvalidInput( + "Shout semantic checker only supports bit-addressed LUT witnesses".into(), + )); + } + Some(LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}) => { + return Err(PiCcsError::InvalidInput( + "Shout semantic checker only supports bit-addressed LUT witnesses".into(), + )); + } Some(LutTableSpec::IdentityU32) => F::from_u64(addrs[j]), None => { if addr >= inst.table.len() { diff --git a/crates/neo-memory/src/sparse_matrix.rs b/crates/neo-memory/src/sparse_matrix.rs new file mode 100644 index 00000000..a4e859cd --- /dev/null +++ b/crates/neo-memory/src/sparse_matrix.rs @@ -0,0 +1,292 @@ +//! Sparse matrix representations for multilinear polynomials. +//! +//! This module is a small, self-contained primitive inspired by Jolt's `read_write_matrix` +//! data structures (`external/jolt/.../read_write_matrix`). The core idea is: +//! +//! - A matrix `M(row, col)` over Boolean hypercubes is represented by its non-zero entries. +//! - Binding (folding) one variable corresponds to combining pairs of indices (2k,2k+1) with +//! weights `(1-r, r)`, exactly like multilinear extension folding. +//! +//! The implementation here is intentionally simple (O(nnz) per bind) and is meant as a +//! correctness-first building block for future "Jolt-like" sparse arguments. + +use p3_field::PrimeCharacteristicRing; + +/// A sparse matrix entry at `(row, col)` with non-zero coefficient `value`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SparseMatEntry { + pub row: u64, + pub col: u64, + pub value: R, +} + +/// Sparse matrix over a `2^ell_row × 2^ell_col` Boolean grid. 
+/// +/// Invariant: +/// - dimensions are tracked by `(ell_row, ell_col)` (no `1 << ell` allocation) +/// - `entries` are strictly increasing by `(row, col)` +/// - all stored values are non-zero +#[derive(Clone, Debug)] +pub struct SparseMat { + ell_row: usize, + ell_col: usize, + entries: Vec>, +} + +impl SparseMat +where + R: PrimeCharacteristicRing + Copy + PartialEq, +{ + pub fn new(ell_row: usize, ell_col: usize) -> Self { + Self { + ell_row, + ell_col, + entries: Vec::new(), + } + } + + pub fn ell_row(&self) -> usize { + self.ell_row + } + + pub fn ell_col(&self) -> usize { + self.ell_col + } + + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + pub fn entries(&self) -> &[SparseMatEntry] { + &self.entries + } + + /// Build a sparse matrix from arbitrary entries. + /// + /// Entries are normalized by: + /// - dropping zeros, + /// - sorting by `(row, col)`, + /// - combining duplicates by summation. + pub fn from_entries(ell_row: usize, ell_col: usize, mut entries: Vec>) -> Self { + entries.retain(|e| e.value != R::ZERO); + entries.sort_by_key(|e| (e.row, e.col)); + + let mut out: Vec> = Vec::with_capacity(entries.len()); + for e in entries { + debug_assert!(ell_row >= 64 || (e.row >> ell_row) == 0, "entry row out of range"); + debug_assert!(ell_col >= 64 || (e.col >> ell_col) == 0, "entry col out of range"); + if let Some(last) = out.last_mut() { + if last.row == e.row && last.col == e.col { + last.value += e.value; + if last.value == R::ZERO { + out.pop(); + } + continue; + } + } + out.push(e); + } + + #[cfg(debug_assertions)] + for w in out.windows(2) { + let a = &w[0]; + let b = &w[1]; + debug_assert!( + (a.row, a.col) < (b.row, b.col), + "SparseMat entries must be strictly increasing" + ); + debug_assert!(a.value != R::ZERO && b.value != R::ZERO, "SparseMat stores zeros"); + } + + Self { + ell_row, + ell_col, + entries: out, + } + } + + /// Get the coefficient at `(row, col)`, returning `0` if not present. 
+ pub fn get(&self, row: u64, col: u64) -> R { + if (self.ell_row < 64 && (row >> self.ell_row) != 0) || (self.ell_col < 64 && (col >> self.ell_col) != 0) { + return R::ZERO; + } + match self + .entries + .binary_search_by_key(&(row, col), |e| (e.row, e.col)) + { + Ok(pos) => self.entries[pos].value, + Err(_) => R::ZERO, + } + } + + /// One multilinear folding round on the least-significant row bit. + /// + /// For each parent row `p` and column `c`: + /// `out[p, c] = in[2p, c] * (1-r) + in[2p+1, c] * r` + pub fn fold_row_round_in_place(&mut self, r: R) { + debug_assert!(self.ell_row > 0, "cannot fold when ell_row == 0"); + + let one_minus_r = R::ONE - r; + let mut out: Vec> = Vec::with_capacity(self.entries.len()); + + for e in self.entries.iter().copied() { + let parent_row = e.row >> 1; + let scaled = if (e.row & 1) == 0 { + e.value * one_minus_r + } else { + e.value * r + }; + if scaled == R::ZERO { + continue; + } + + if let Some(last) = out.last_mut() { + if last.row == parent_row && last.col == e.col { + last.value += scaled; + if last.value == R::ZERO { + out.pop(); + } + continue; + } + } + + out.push(SparseMatEntry { + row: parent_row, + col: e.col, + value: scaled, + }); + } + + self.entries = out; + self.ell_row -= 1; + } + + /// One multilinear folding round on the least-significant column bit. + /// + /// For each row `r` and parent column `p`: + /// `out[r, p] = in[r, 2p] * (1-r) + in[r, 2p+1] * r` + pub fn fold_col_round_in_place(&mut self, r: R) { + debug_assert!(self.ell_col > 0, "cannot fold when ell_col == 0"); + + let one_minus_r = R::ONE - r; + let mut out: Vec> = Vec::with_capacity(self.entries.len()); + + // The input is strictly sorted by `(row, col)`. Since `parent_col = col >> 1` is + // non-decreasing in `col`, the folded entries remain sorted by `(row, parent_col)`. + // Duplicates can occur only when both `2p` and `2p+1` are present; those are adjacent + // in the input order, so we can merge on the fly without a global re-sort. 
+ for e in self.entries.iter().copied() { + let parent_col = e.col >> 1; + let scaled = if (e.col & 1) == 0 { + e.value * one_minus_r + } else { + e.value * r + }; + if scaled == R::ZERO { + continue; + } + + if let Some(last) = out.last_mut() { + if last.row == e.row && last.col == parent_col { + last.value += scaled; + if last.value == R::ZERO { + out.pop(); + } + continue; + } + } + + out.push(SparseMatEntry { + row: e.row, + col: parent_col, + value: scaled, + }); + } + + self.entries = out; + self.ell_col -= 1; + } + + /// Evaluate the multilinear extension of the sparse matrix at `(r_row, r_col)` + /// by repeated folding. + pub fn mle_eval_by_folding(&self, r_row: &[R], r_col: &[R]) -> Result { + if self.ell_row != r_row.len() { + return Err(format!( + "SparseMat: ell_row={} does not match r_row.len()={}", + self.ell_row, + r_row.len() + )); + } + if self.ell_col != r_col.len() { + return Err(format!( + "SparseMat: ell_col={} does not match r_col.len()={}", + self.ell_col, + r_col.len() + )); + } + + let mut cur = self.clone(); + for &r in r_row { + cur.fold_row_round_in_place(r); + } + for &r in r_col { + cur.fold_col_round_in_place(r); + } + if cur.ell_row != 0 || cur.ell_col != 0 { + return Err("SparseMat: folding did not reach ell_row=0, ell_col=0".into()); + } + Ok(cur.entries.first().map(|e| e.value).unwrap_or(R::ZERO)) + } + + /// Evaluate the multilinear extension of the sparse matrix at `(r_row, r_col)` by direct sparse summation. + /// + /// This computes: + /// `Σ_{(i,j)∈supp} M[i,j] · χ_{r_row}(i) · χ_{r_col}(j)`. 
+ pub fn mle_eval_direct(&self, r_row: &[R], r_col: &[R]) -> Result { + if self.ell_row != r_row.len() { + return Err(format!( + "SparseMat: ell_row={} does not match r_row.len()={}", + self.ell_row, + r_row.len() + )); + } + if self.ell_col != r_col.len() { + return Err(format!( + "SparseMat: ell_col={} does not match r_col.len()={}", + self.ell_col, + r_col.len() + )); + } + + #[inline] + fn chi_at_u64_index(r: &[R], idx: u64) -> R { + let mut acc = R::ONE; + for (b, &rb) in r.iter().enumerate() { + let bit = if b < 64 { (idx >> b) & 1 } else { 0 }; + acc *= if bit == 1 { rb } else { R::ONE - rb }; + } + acc + } + + let mut acc = R::ZERO; + for e in self.entries.iter().copied() { + if self.ell_row < 64 && (e.row >> self.ell_row) != 0 { + return Err("SparseMat: entry row out of range for ell_row".into()); + } + if self.ell_col < 64 && (e.col >> self.ell_col) != 0 { + return Err("SparseMat: entry col out of range for ell_col".into()); + } + + let chi_row = chi_at_u64_index(r_row, e.row); + if chi_row == R::ZERO { + continue; + } + let chi_col = chi_at_u64_index(r_col, e.col); + if chi_col == R::ZERO { + continue; + } + acc += e.value * chi_row * chi_col; + } + Ok(acc) + } +} diff --git a/crates/neo-memory/src/twist_oracle.rs b/crates/neo-memory/src/twist_oracle.rs index 0a06c741..20e4a17a 100644 --- a/crates/neo-memory/src/twist_oracle.rs +++ b/crates/neo-memory/src/twist_oracle.rs @@ -18,6 +18,7 @@ use crate::mle::{eq_single, lt_eval}; use crate::sparse_time::SparseIdxVec; use neo_math::K; use neo_reductions::sumcheck::RoundOracle; +use p3_field::Field; use p3_field::PrimeCharacteristicRing; macro_rules! 
impl_round_oracle_via_core { @@ -305,6 +306,4936 @@ impl RoundOracle for ShoutValueOracleSparse { } } +// ============================================================================ +// Packed-key RV32 ADD Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 packed ADD correctness: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) + rhs(t) - val(t) - carry(t)·2^32) +/// +/// This is a "no width bloat" alternative to the 64-bit addr-bit Shout encoding for `ADD`: +/// instead of committing to 64 key bits, we commit to packed `lhs/rhs` plus a carry bit. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedAddOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + carry: SparseIdxVec, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedAddOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + carry: SparseIdxVec, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(carry.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + carry, + val, + degree_bound: 3, + } + } +} + +impl RoundOracle for Rv32PackedAddOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let two32 = K::from_u64(1u64 << 32); + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let carry = self.carry.singleton_value(); + let val = self.val.singleton_value(); + let expr = lhs + rhs - val - 
carry * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let two32 = K::from_u64(1u64 << 32); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let carry0 = self.carry.get(child0); + let carry1 = self.carry.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let expr0 = lhs0 + rhs0 - val0 - carry0 * two32; + let expr1 = lhs1 + rhs1 - val1 - carry1 * two32; + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.carry.fold_round_in_place(r); + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed SUB correctness: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · 
(lhs(t) - rhs(t) - val(t) + borrow(t)·2^32) +/// +/// This is a "no width bloat" alternative to the 64-bit addr-bit Shout encoding for `SUB`: +/// instead of committing to 64 key bits, we commit to packed `lhs/rhs` plus a borrow bit. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedSubOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + borrow: SparseIdxVec, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedSubOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + borrow: SparseIdxVec, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(borrow.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + borrow, + val, + degree_bound: 3, + } + } +} + +impl RoundOracle for Rv32PackedSubOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let two32 = K::from_u64(1u64 << 32); + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let borrow = self.borrow.singleton_value(); + let val = self.val.singleton_value(); + let expr = lhs - rhs - val + borrow * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let two32 = K::from_u64(1u64 << 32); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let 
child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let borrow0 = self.borrow.get(child0); + let borrow1 = self.borrow.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let expr0 = lhs0 - rhs0 - val0 + borrow0 * two32; + let expr1 = lhs1 - rhs1 - val1 + borrow1 * two32; + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.borrow.fold_round_in_place(r); + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +// ============================================================================ +// Packed-key RV32 MUL Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 packed MUL correctness: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t)·rhs(t) - val(t) - carry(t)·2^32) +/// +/// Where: +/// - `carry(t)` is the high 32 bits of the 64-bit product `lhs·rhs`, encoded as 32 Boolean columns, +/// - `val(t)` is the 
low 32 bits (the `MUL` result). +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedMulOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + carry_bits: Vec>, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedMulOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + carry_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(carry_bits.len(), 32); + for (i, b) in carry_bits.iter().enumerate() { + debug_assert_eq!( + b.len(), + 1usize << ell_n, + "carry_bits[{i}] length must match time domain" + ); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + carry_bits, + val, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - lhs(t)·rhs(t): degree 2 + // - val(t), carry(t): multilinear (degree 1) + // ⇒ total degree ≤ 1 + 1 + 2 = 4 + degree_bound: 4, + } + } +} + +impl RoundOracle for Rv32PackedMulOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let val = self.val.singleton_value(); + + let mut carry = K::ZERO; + for (i, b) in self.carry_bits.iter().enumerate() { + carry += b.singleton_value() * K::from_u64(1u64 << i); + } + + let expr = lhs * rhs - val - carry * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = 
gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + // Pre-fetch carry bit endpoints for this pair to avoid repeated sparse lookups per eval point. + let mut c0s: [K; 32] = [K::ZERO; 32]; + let mut c1s: [K; 32] = [K::ZERO; 32]; + for (i, b) in self.carry_bits.iter().enumerate() { + c0s[i] = b.get(child0); + c1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = interp(rhs0, rhs1, x); + let val_x = interp(val0, val1, x); + + let mut carry_x = K::ZERO; + for j in 0..32 { + let c_x = interp(c0s[j], c1s[j], x); + carry_x += c_x * K::from_u64(1u64 << j); + } + + let expr_x = lhs_x * rhs_x - val_x - carry_x * two32; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + 
self.rhs.fold_round_in_place(r); + for b in self.carry_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +// ============================================================================ +// Packed-key RV32 MULHU Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 packed MULHU correctness (unsigned high 32 bits): +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t)·rhs(t) - lo(t) - val(t)·2^32) +/// +/// Where: +/// - `lo(t)` is the low 32 bits of the 64-bit product `lhs·rhs`, encoded as 32 Boolean columns, +/// - `val(t)` is the upper 32 bits (the `MULHU` result). +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedMulhuOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lo_bits: Vec>, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedMulhuOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lo_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(lo_bits.len(), 32); + for (i, b) in lo_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "lo_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + lo_bits, + val, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - lhs(t)·rhs(t): degree 2 + // - lo(t), val(t): multilinear (degree 1) + // ⇒ total degree ≤ 1 + 1 + 2 = 4 + degree_bound: 4, + } + } +} + 
+impl RoundOracle for Rv32PackedMulhuOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let val = self.val.singleton_value(); + + let mut lo = K::ZERO; + for (i, b) in self.lo_bits.iter().enumerate() { + lo += b.singleton_value() * K::from_u64(1u64 << i); + } + + let expr = lhs * rhs - lo - val * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + // Pre-fetch lo bit endpoints for this pair to avoid repeated sparse lookups per eval point. 
+ let mut lo0s: [K; 32] = [K::ZERO; 32]; + let mut lo1s: [K; 32] = [K::ZERO; 32]; + for (i, b) in self.lo_bits.iter().enumerate() { + lo0s[i] = b.get(child0); + lo1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = interp(rhs0, rhs1, x); + let val_x = interp(val0, val1, x); + + let mut lo_x = K::ZERO; + for j in 0..32 { + let b_x = interp(lo0s[j], lo1s[j], x); + lo_x += b_x * K::from_u64(1u64 << j); + } + + let expr_x = lhs_x * rhs_x - lo_x - val_x * two32; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + for b in self.lo_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +// ============================================================================ +// Packed-key RV32 MULH / MULHSU helpers (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 unsigned product decomposition: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t)·rhs(t) - lo(t) - hi(t)·2^32) +/// +/// Where: +/// - `lo(t)` is the low 32 bits of the 64-bit product, encoded as 32 Boolean columns, +/// - `hi(t)` is a witness column intended to be the upper 32 bits. 
+/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedMulHiOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lo_bits: Vec>, + hi: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedMulHiOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lo_bits: Vec>, + hi: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(hi.len(), 1usize << ell_n); + debug_assert_eq!(lo_bits.len(), 32); + for (i, b) in lo_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "lo_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + lo_bits, + hi, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - lhs(t)·rhs(t): degree 2 + // ⇒ total degree ≤ 1 + 1 + 2 = 4 + degree_bound: 4, + } + } +} + +impl RoundOracle for Rv32PackedMulHiOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let hi = self.hi.singleton_value(); + + let mut lo = K::ZERO; + for (i, b) in self.lo_bits.iter().enumerate() { + lo += b.singleton_value() * K::from_u64(1u64 << i); + } + + let expr = lhs * rhs - lo - hi * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys 
= vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let hi0 = self.hi.get(child0); + let hi1 = self.hi.get(child1); + + // Pre-fetch lo bit endpoints for this pair to avoid repeated sparse lookups per eval point. + let mut lo0s: [K; 32] = [K::ZERO; 32]; + let mut lo1s: [K; 32] = [K::ZERO; 32]; + for (i, b) in self.lo_bits.iter().enumerate() { + lo0s[i] = b.get(child0); + lo1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = interp(rhs0, rhs1, x); + let hi_x = interp(hi0, hi1, x); + + let mut lo_x = K::ZERO; + for j in 0..32 { + let b_x = interp(lo0s[j], lo1s[j], x); + lo_x += b_x * K::from_u64(1u64 << j); + } + + let expr_x = lhs_x * rhs_x - lo_x - hi_x * two32; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + for b in self.lo_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.hi.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse 
Route A oracle for RV32 packed MULH signed correction: +/// Σ_t χ(t)·has_lookup(t)·(w0·(hi - s1·rhs - s2·lhs + k·2^32 - val) + w1·k(k-1)(k-2)) +/// +/// Where: +/// - `hi` is the upper 32 bits of the unsigned product `lhs·rhs`, +/// - `s1`, `s2` are witness sign bits (msb of lhs/rhs), +/// - `k ∈ {0,1,2}` accounts for mod-2^32 normalization. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedMulhAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + hi: SparseIdxVec, + k: SparseIdxVec, + val: SparseIdxVec, + weights: [K; 2], + degree_bound: usize, +} + +impl Rv32PackedMulhAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + hi: SparseIdxVec, + k: SparseIdxVec, + val: SparseIdxVec, + weights: [K; 2], + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(hi.len(), 1usize << ell_n); + debug_assert_eq!(k.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + lhs_sign, + rhs_sign, + hi, + k, + val, + weights, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - eq expr: degree 2 (sign·rhs) + // - range poly: degree 3 + // ⇒ total degree ≤ 1 + 1 + 3 = 5 + degree_bound: 5, + } + } +} + +impl RoundOracle for Rv32PackedMulhAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 
<< 32); + let w0 = self.weights[0]; + let w1 = self.weights[1]; + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let lhs_sign = self.lhs_sign.singleton_value(); + let rhs_sign = self.rhs_sign.singleton_value(); + let hi = self.hi.singleton_value(); + let k = self.k.singleton_value(); + let val = self.val.singleton_value(); + + let eq_expr = hi - lhs_sign * rhs - rhs_sign * lhs + k * two32 - val; + let range = k * (k - K::ONE) * (k - K::from_u64(2)); + let expr = w0 * eq_expr + w1 * range; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let lhs_sign0 = self.lhs_sign.get(child0); + let lhs_sign1 = self.lhs_sign.get(child1); + let rhs_sign0 = self.rhs_sign.get(child0); + let rhs_sign1 = self.rhs_sign.get(child1); + let hi0 = self.hi.get(child0); + let hi1 = self.hi.get(child1); + let k0 = self.k.get(child0); + let k1 = self.k.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = 
interp(rhs0, rhs1, x); + let lhs_sign_x = interp(lhs_sign0, lhs_sign1, x); + let rhs_sign_x = interp(rhs_sign0, rhs_sign1, x); + let hi_x = interp(hi0, hi1, x); + let k_x = interp(k0, k1, x); + let val_x = interp(val0, val1, x); + + let eq_expr = hi_x - lhs_sign_x * rhs_x - rhs_sign_x * lhs_x + k_x * two32 - val_x; + let range = k_x * (k_x - K::ONE) * (k_x - K::from_u64(2)); + let expr_x = w0 * eq_expr + w1 * range; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.lhs_sign.fold_round_in_place(r); + self.rhs_sign.fold_round_in_place(r); + self.hi.fold_round_in_place(r); + self.k.fold_round_in_place(r); + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed MULHSU signed correction: +/// Σ_t χ(t)·has_lookup(t)·(hi - s·rhs - val + b·2^32) +/// +/// Where: +/// - `hi` is the upper 32 bits of the unsigned product `lhs·rhs`, +/// - `s` is the witness sign bit of `lhs`, +/// - `b ∈ {0,1}` is a borrow bit for mod-2^32 normalization. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedMulhsuAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + hi: SparseIdxVec, + borrow: SparseIdxVec, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedMulhsuAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + hi: SparseIdxVec, + borrow: SparseIdxVec, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(hi.len(), 1usize << ell_n); + debug_assert_eq!(borrow.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + lhs_sign, + hi, + borrow, + val, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - eq expr: degree 2 (sign·rhs) + // ⇒ total degree ≤ 1 + 1 + 2 = 4 + degree_bound: 4, + } + } +} + +impl RoundOracle for Rv32PackedMulhsuAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let rhs = self.rhs.singleton_value(); + let lhs_sign = self.lhs_sign.singleton_value(); + let hi = self.hi.singleton_value(); + let borrow = self.borrow.singleton_value(); + let val = self.val.singleton_value(); + let expr = hi - lhs_sign * rhs - val + borrow * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| 
p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let lhs_sign0 = self.lhs_sign.get(child0); + let lhs_sign1 = self.lhs_sign.get(child1); + let hi0 = self.hi.get(child0); + let hi1 = self.hi.get(child1); + let borrow0 = self.borrow.get(child0); + let borrow1 = self.borrow.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let rhs_x = interp(rhs0, rhs1, x); + let lhs_sign_x = interp(lhs_sign0, lhs_sign1, x); + let hi_x = interp(hi0, hi1, x); + let borrow_x = interp(borrow0, borrow1, x); + let val_x = interp(val0, val1, x); + + let expr_x = hi_x - lhs_sign_x * rhs_x - val_x + borrow_x * two32; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.lhs_sign.fold_round_in_place(r); + self.hi.fold_round_in_place(r); + self.borrow.fold_round_in_place(r); + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed EQ correctness 
(Ajtai-representable): +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (val(t) - Π_i (1 - diff_bit_i(t))) +/// +/// Where `diff_bits` encode `diff = lhs - rhs (mod 2^32)` (checked by the adapter oracle). +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedEqOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + diff_bits: Vec>, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedEqOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + diff_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(diff_bits.len(), 32); + for (i, b) in diff_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + diff_bits, + val, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - Π_i (1 - diff_bit_i(t)): product of 32 linear terms (degree 32) + // ⇒ total degree ≤ 1 + 1 + 32 = 34 + degree_bound: 34, + } + } +} + +impl RoundOracle for Rv32PackedEqOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let val = self.val.singleton_value(); + let mut prod = K::ONE; + for b in self.diff_bits.iter() { + let bit = b.singleton_value(); + prod *= K::ONE - bit; + } + let expr = val - prod; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let 
child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let val_x = interp(val0, val1, x); + let mut prod_x = K::ONE; + for b in self.diff_bits.iter() { + let b0 = b.get(child0); + let b1 = b.get(child1); + let bit_x = interp(b0, b1, x); + prod_x *= K::ONE - bit_x; + } + let expr_x = val_x - prod_x; + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + for b in self.diff_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed EQ/NEQ diff decomposition check: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - rhs(t) - diff(t) + borrow(t)·2^32) +/// +/// Where: +/// - `borrow(t)` is the SUB borrow bit (1 iff lhs < rhs), +/// - `diff(t) = Σ_i 2^i · diff_bit_i(t)` is the u32 wraparound difference `lhs - rhs (mod 2^32)`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedEqAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + borrow: SparseIdxVec, + diff_bits: Vec>, + degree_bound: usize, +} + +impl Rv32PackedEqAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + borrow: SparseIdxVec, + diff_bits: Vec>, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(borrow.len(), 1usize << ell_n); + debug_assert_eq!(diff_bits.len(), 32); + for (i, b) in diff_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + borrow, + diff_bits, + degree_bound: 3, + } + } +} + +impl RoundOracle for Rv32PackedEqAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let two32 = K::from_u64(1u64 << 32); + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let borrow = self.borrow.singleton_value(); + let mut diff = K::ZERO; + for (i, b) in self.diff_bits.iter().enumerate() { + let bit = b.singleton_value(); + diff += bit * K::from_u64(1u64 << i); + } + let expr = lhs - rhs - diff + borrow * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let two32 = K::from_u64(1u64 << 32); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = 
self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let borrow0 = self.borrow.get(child0); + let borrow1 = self.borrow.get(child1); + let mut diff0 = K::ZERO; + let mut diff1 = K::ZERO; + for (i, b) in self.diff_bits.iter().enumerate() { + let pow = K::from_u64(1u64 << i); + diff0 += b.get(child0) * pow; + diff1 += b.get(child1) * pow; + } + let expr0 = lhs0 - rhs0 - diff0 + borrow0 * two32; + let expr1 = lhs1 - rhs1 - diff1 + borrow1 * two32; + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.borrow.fold_round_in_place(r); + for b in self.diff_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed NEQ correctness (Ajtai-representable): +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (val(t) - (1 - Π_i (1 - diff_bit_i(t)))) +/// +/// Where `diff_bits` encode `diff = lhs - rhs (mod 2^32)` (checked by the adapter oracle). +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedNeqOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + diff_bits: Vec>, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedNeqOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + diff_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(diff_bits.len(), 32); + for (i, b) in diff_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + diff_bits, + val, + // Same degree bound as EQ: 1 + 1 + 32 = 34. + degree_bound: 34, + } + } +} + +impl RoundOracle for Rv32PackedNeqOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let val = self.val.singleton_value(); + let mut prod = K::ONE; + for b in self.diff_bits.iter() { + let bit = b.singleton_value(); + prod *= K::ONE - bit; + } + let expr = val + prod - K::ONE; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x 
== K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let val_x = interp(val0, val1, x); + let mut prod_x = K::ONE; + for b in self.diff_bits.iter() { + let b0 = b.get(child0); + let b1 = b.get(child1); + let bit_x = interp(b0, b1, x); + prod_x *= K::ONE - bit_x; + } + let expr_x = val_x + prod_x - K::ONE; + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + for b in self.diff_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed NEQ diff decomposition check: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - rhs(t) - diff(t) + borrow(t)·2^32) +/// +/// See [`Rv32PackedEqAdapterOracleSparseTime`] for details; this is identical but kept as a +/// distinct type for call-site clarity. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedNeqAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + borrow: SparseIdxVec, + diff_bits: Vec>, + degree_bound: usize, +} + +impl Rv32PackedNeqAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + borrow: SparseIdxVec, + diff_bits: Vec>, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(borrow.len(), 1usize << ell_n); + debug_assert_eq!(diff_bits.len(), 32); + for (i, b) in diff_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + borrow, + diff_bits, + degree_bound: 3, + } + } +} + +impl RoundOracle for Rv32PackedNeqAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let two32 = K::from_u64(1u64 << 32); + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let borrow = self.borrow.singleton_value(); + let mut diff = K::ZERO; + for (i, b) in self.diff_bits.iter().enumerate() { + let bit = b.singleton_value(); + diff += bit * K::from_u64(1u64 << i); + } + let expr = lhs - rhs - diff + borrow * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let two32 = K::from_u64(1u64 << 32); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = 
self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let borrow0 = self.borrow.get(child0); + let borrow1 = self.borrow.get(child1); + let mut diff0 = K::ZERO; + let mut diff1 = K::ZERO; + for (i, b) in self.diff_bits.iter().enumerate() { + let pow = K::from_u64(1u64 << i); + diff0 += b.get(child0) * pow; + diff1 += b.get(child1) * pow; + } + let expr0 = lhs0 - rhs0 - diff0 + borrow0 * two32; + let expr1 = lhs1 - rhs1 - diff1 + borrow1 * two32; + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.borrow.fold_round_in_place(r); + for b in self.diff_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +// ============================================================================ +// Packed-key RV32 SLTU Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 packed SLT (signed less-than) correctness. 
+/// +/// We reduce signed comparison to unsigned comparison by XOR-biasing both operands with `2^31` +/// (flip the sign bit). Let: +/// lhs_b = lhs ⊕ 2^31, rhs_b = rhs ⊕ 2^31 +/// then `(lhs as i32) < (rhs as i32)` iff `lhs_b < rhs_b` as unsigned. +/// +/// With witness bits `lhs_sign`, `rhs_sign` (intended as the msb of lhs/rhs), we implement the +/// XOR-biasing arithmetically: +/// x ⊕ 2^31 = x + (1 - 2·msb(x))·2^31. +/// +/// The correctness constraint is: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs_b(t) - rhs_b(t) - diff(t) + out(t)·2^32) +/// +/// Where `out(t)` is the SLT result bit (1 iff lhs < rhs in signed order, else 0) and `diff(t)` +/// is the u32 difference `lhs_b - rhs_b (mod 2^32)`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedSltOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + diff: SparseIdxVec, + out: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedSltOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + diff: SparseIdxVec, + out: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(diff.len(), 1usize << ell_n); + debug_assert_eq!(out.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + lhs_sign, + rhs_sign, + diff, + out, + degree_bound: 3, + } + } +} + +impl RoundOracle for Rv32PackedSltOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two31 = 
K::from_u64(1u64 << 31); + let two32 = K::from_u64(1u64 << 32); + let two = K::from_u64(2); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let lhs_sign = self.lhs_sign.singleton_value(); + let rhs_sign = self.rhs_sign.singleton_value(); + let diff = self.diff.singleton_value(); + let out = self.out.singleton_value(); + + let lhs_b = lhs + (K::ONE - two * lhs_sign) * two31; + let rhs_b = rhs + (K::ONE - two * rhs_sign) * two31; + let expr = lhs_b - rhs_b - diff + out * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let lhs_sign0 = self.lhs_sign.get(child0); + let lhs_sign1 = self.lhs_sign.get(child1); + let rhs_sign0 = self.rhs_sign.get(child0); + let rhs_sign1 = self.rhs_sign.get(child1); + let diff0 = self.diff.get(child0); + let diff1 = self.diff.get(child1); + let out0 = self.out.get(child0); + let out1 = self.out.get(child1); + + let lhs_b0 = lhs0 + (K::ONE - two * lhs_sign0) * two31; + let lhs_b1 = lhs1 + (K::ONE - two * lhs_sign1) * two31; + let rhs_b0 = rhs0 + (K::ONE - two * rhs_sign0) * two31; + let rhs_b1 = rhs1 + (K::ONE - two * rhs_sign1) * two31; + + let expr0 = lhs_b0 - rhs_b0 - diff0 + out0 * two32; + let expr1 = lhs_b1 - rhs_b1 - diff1 + out1 * two32; + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, 
self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.lhs_sign.fold_round_in_place(r); + self.rhs_sign.fold_round_in_place(r); + self.diff.fold_round_in_place(r); + self.out.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed SLTU correctness: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - rhs(t) - diff(t) + out(t)·2^32) +/// +/// Where `out(t)` is the SLTU result bit (1 iff lhs < rhs, else 0) and `diff(t)` is the u32 +/// difference `lhs - rhs (mod 2^32)`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedSltuOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + diff: SparseIdxVec, + out: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedSltuOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + diff: SparseIdxVec, + out: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(diff.len(), 1usize << ell_n); + debug_assert_eq!(out.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + diff, + out, + degree_bound: 3, + } + } +} + +impl RoundOracle for Rv32PackedSltuOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let two32 = K::from_u64(1u64 << 32); + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let diff = self.diff.singleton_value(); + let out = self.out.singleton_value(); + let expr = lhs - rhs - diff + out * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let two32 = K::from_u64(1u64 << 32); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let diff0 = 
self.diff.get(child0); + let diff1 = self.diff.get(child1); + let out0 = self.out.get(child0); + let out1 = self.out.get(child1); + + let expr0 = lhs0 - rhs0 - diff0 + out0 * two32; + let expr1 = lhs1 - rhs1 - diff1 + out1 * two32; + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.diff.fold_round_in_place(r); + self.out.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +// ============================================================================ +// Packed-key RV32 SLL Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 packed SLL correctness: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) · 2^{shamt(t)} - val(t) - carry(t)·2^32) +/// +/// Where: +/// - `shamt(t)` is the shift amount (0..31), encoded as 5 Boolean columns, +/// - `carry(t)` is the high part of `(lhs << shamt)` as a u32 (range-checked separately), +/// - `val(t)` is the low 32 bits of `(lhs << shamt)`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedSllOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + carry_bits: Vec>, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedSllOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + carry_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(shamt_bits.len(), 5); + for (i, b) in shamt_bits.iter().enumerate() { + debug_assert_eq!( + b.len(), + 1usize << ell_n, + "shamt_bits[{i}] length must match time domain" + ); + } + debug_assert_eq!(carry_bits.len(), 32); + for (i, b) in carry_bits.iter().enumerate() { + debug_assert_eq!( + b.len(), + 1usize << ell_n, + "carry_bits[{i}] length must match time domain" + ); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + shamt_bits, + carry_bits, + val, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - lhs(t): multilinear (degree 1) + // - 2^{shamt(t)}: product of 5 linear terms in the shamt bits (degree 5) + // - val(t), carry(t): multilinear (degree 1) + // ⇒ total degree ≤ 1 + 1 + max(1+5, 1, 1) = 8 + degree_bound: 8, + } + } +} + +impl RoundOracle for Rv32PackedSllOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + let pow2_const: [K; 5] = [ + K::from_u64(2), + K::from_u64(4), + K::from_u64(16), + K::from_u64(256), + K::from_u64(65536), + ]; + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let val = self.val.singleton_value(); + + let mut pow2 = K::ONE; + for (b, c) in 
self.shamt_bits.iter().zip(pow2_const.iter()) { + let bit = b.singleton_value(); + pow2 *= K::ONE + bit * (*c - K::ONE); + } + + let mut carry = K::ZERO; + for (i, b) in self.carry_bits.iter().enumerate() { + carry += b.singleton_value() * K::from_u64(1u64 << i); + } + + let expr = lhs * pow2 - val - carry * two32; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. 
+ let mut b0s: [K; 5] = [K::ZERO; 5]; + let mut b1s: [K; 5] = [K::ZERO; 5]; + for (i, b) in self.shamt_bits.iter().enumerate() { + b0s[i] = b.get(child0); + b1s[i] = b.get(child1); + } + let mut c0s: [K; 32] = [K::ZERO; 32]; + let mut c1s: [K; 32] = [K::ZERO; 32]; + for (i, b) in self.carry_bits.iter().enumerate() { + c0s[i] = b.get(child0); + c1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let val_x = interp(val0, val1, x); + + let mut pow2_x = K::ONE; + for j in 0..5 { + let b_x = interp(b0s[j], b1s[j], x); + pow2_x *= K::ONE + b_x * (pow2_const[j] - K::ONE); + } + + let mut carry_x = K::ZERO; + for j in 0..32 { + let c_x = interp(c0s[j], c1s[j], x); + carry_x += c_x * K::from_u64(1u64 << j); + } + + let expr_x = lhs_x * pow2_x - val_x - carry_x * two32; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + for b in self.shamt_bits.iter_mut() { + b.fold_round_in_place(r); + } + for b in self.carry_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +// ============================================================================ +// Packed-key RV32 SRL Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for 
RV32 packed SRL correctness: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - val(t)·2^{shamt(t)} - rem(t)) +/// +/// Where: +/// - `shamt(t)` is the shift amount (0..31), encoded as 5 Boolean columns, +/// - `rem(t)` is the remainder `lhs mod 2^{shamt}`, encoded as 32 Boolean columns, +/// - `val(t)` is the SRL result (`lhs >> shamt`). +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedSrlOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedSrlOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(shamt_bits.len(), 5); + for (i, b) in shamt_bits.iter().enumerate() { + debug_assert_eq!( + b.len(), + 1usize << ell_n, + "shamt_bits[{i}] length must match time domain" + ); + } + debug_assert_eq!(rem_bits.len(), 32); + for (i, b) in rem_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + shamt_bits, + rem_bits, + val, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - lhs(t): multilinear (degree 1) + // - 2^{shamt(t)}: product of 5 linear terms in the shamt bits (degree 5) + // - val(t), rem(t): multilinear (degree 1) + // ⇒ total degree ≤ 1 + 1 + max(1+5, 1, 1) = 8 + degree_bound: 8, + } + } +} + +impl RoundOracle for Rv32PackedSrlOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let pow2_const: [K; 5] = 
[ + K::from_u64(2), + K::from_u64(4), + K::from_u64(16), + K::from_u64(256), + K::from_u64(65536), + ]; + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let val = self.val.singleton_value(); + + let mut pow2 = K::ONE; + for (b, c) in self.shamt_bits.iter().zip(pow2_const.iter()) { + let bit = b.singleton_value(); + pow2 *= K::ONE + bit * (*c - K::ONE); + } + + let mut rem = K::ZERO; + for (i, b) in self.rem_bits.iter().enumerate() { + rem += b.singleton_value() * K::from_u64(1u64 << i); + } + + let expr = lhs - val * pow2 - rem; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. 
+ let mut b0s: [K; 5] = [K::ZERO; 5]; + let mut b1s: [K; 5] = [K::ZERO; 5]; + for (i, b) in self.shamt_bits.iter().enumerate() { + b0s[i] = b.get(child0); + b1s[i] = b.get(child1); + } + let mut r0s: [K; 32] = [K::ZERO; 32]; + let mut r1s: [K; 32] = [K::ZERO; 32]; + for (i, b) in self.rem_bits.iter().enumerate() { + r0s[i] = b.get(child0); + r1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let val_x = interp(val0, val1, x); + + let mut pow2_x = K::ONE; + for j in 0..5 { + let b_x = interp(b0s[j], b1s[j], x); + pow2_x *= K::ONE + b_x * (pow2_const[j] - K::ONE); + } + + let mut rem_x = K::ZERO; + for j in 0..32 { + let r_x = interp(r0s[j], r1s[j], x); + rem_x += r_x * K::from_u64(1u64 << j); + } + + let expr_x = lhs_x - val_x * pow2_x - rem_x; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + for b in self.shamt_bits.iter_mut() { + b.fold_round_in_place(r); + } + for b in self.rem_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle that enforces the SRL remainder is < 2^{shamt}: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · Σ_s eq(shamt(t), s) · Σ_{i≥s} 2^i · rem_bit_i(t) +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedSrlAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, + degree_bound: usize, +} + +impl Rv32PackedSrlAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(shamt_bits.len(), 5); + for (i, b) in shamt_bits.iter().enumerate() { + debug_assert_eq!( + b.len(), + 1usize << ell_n, + "shamt_bits[{i}] length must match time domain" + ); + } + debug_assert_eq!(rem_bits.len(), 32); + for (i, b) in rem_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + shamt_bits, + rem_bits, + // Degree bound: chi (1) + gate (1) + eq_s(shamt) (5) + tail(rem) (1) = 8. 
+ degree_bound: 8, + } + } +} + +impl RoundOracle for Rv32PackedSrlAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + + let mut shamt: [K; 5] = [K::ZERO; 5]; + for (i, b) in self.shamt_bits.iter().enumerate() { + shamt[i] = b.singleton_value(); + } + + let mut rem: [K; 32] = [K::ZERO; 32]; + for (i, b) in self.rem_bits.iter().enumerate() { + rem[i] = b.singleton_value(); + } + + // tail_sum[s] = Σ_{i≥s} 2^i · rem_i + let mut tail_sum: [K; 32] = [K::ZERO; 32]; + let mut tail = K::ZERO; + for i in (0..32).rev() { + tail += rem[i] * K::from_u64(1u64 << i); + tail_sum[i] = tail; + } + + // eq_s(shamt) for s in 0..32 + let mut eq: [K; 32] = [K::ZERO; 32]; + for s in 0..32usize { + let mut prod = K::ONE; + for j in 0..5usize { + let b = shamt[j]; + if ((s >> j) & 1) == 1 { + prod *= b; + } else { + prod *= K::ONE - b; + } + } + eq[s] = prod; + } + + let mut expr = K::ZERO; + for s in 0..32usize { + expr += eq[s] * tail_sum[s]; + } + + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. 
+ let mut b0s: [K; 5] = [K::ZERO; 5]; + let mut b1s: [K; 5] = [K::ZERO; 5]; + for (i, b) in self.shamt_bits.iter().enumerate() { + b0s[i] = b.get(child0); + b1s[i] = b.get(child1); + } + let mut r0s: [K; 32] = [K::ZERO; 32]; + let mut r1s: [K; 32] = [K::ZERO; 32]; + for (i, b) in self.rem_bits.iter().enumerate() { + r0s[i] = b.get(child0); + r1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let mut shamt: [K; 5] = [K::ZERO; 5]; + for j in 0..5usize { + shamt[j] = interp(b0s[j], b1s[j], x); + } + + let mut rem: [K; 32] = [K::ZERO; 32]; + for j in 0..32usize { + rem[j] = interp(r0s[j], r1s[j], x); + } + + // tail_sum[s] = Σ_{i≥s} 2^i · rem_i + let mut tail_sum: [K; 32] = [K::ZERO; 32]; + let mut tail = K::ZERO; + for j in (0..32).rev() { + tail += rem[j] * K::from_u64(1u64 << j); + tail_sum[j] = tail; + } + + // eq_s(shamt) for s in 0..32 + let mut expr_x = K::ZERO; + for s in 0..32usize { + let mut prod = K::ONE; + for j in 0..5usize { + let b = shamt[j]; + if ((s >> j) & 1) == 1 { + prod *= b; + } else { + prod *= K::ONE - b; + } + } + expr_x += prod * tail_sum[s]; + } + + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + for b in self.shamt_bits.iter_mut() { + b.fold_round_in_place(r); + } + for b in self.rem_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +// 
============================================================================ +// Packed-key RV32 SRA Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 packed SRA correctness: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - val(t)·2^{shamt(t)} - rem(t) - sign(t)·2^32·(1 - 2^{shamt(t)})) +/// +/// Where: +/// - `shamt(t)` is the shift amount (0..31), encoded as 5 Boolean columns, +/// - `sign(t)` is the sign bit of `lhs` (and the expected sign bit of `val`), encoded as 1 Boolean column, +/// - `rem(t)` is the remainder in the signed floor-division identity: +/// (lhs_signed) = (val_signed)·2^{shamt} + rem, with rem ∈ [0, 2^{shamt}) +/// encoded as 31 Boolean columns (bits 0..30), +/// - `val(t)` is the SRA result (`(lhs as i32) >> shamt`, represented as u32 in the field). +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedSraOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + sign: SparseIdxVec, + rem_bits: Vec>, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedSraOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + sign: SparseIdxVec, + rem_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(sign.len(), 1usize << ell_n); + debug_assert_eq!(shamt_bits.len(), 5); + for (i, b) in shamt_bits.iter().enumerate() { + debug_assert_eq!( + b.len(), + 1usize << ell_n, + "shamt_bits[{i}] length must match time domain" + ); + } + debug_assert_eq!(rem_bits.len(), 31); + for (i, b) in rem_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must 
match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + shamt_bits, + sign, + rem_bits, + val, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - lhs(t): multilinear (degree 1) + // - 2^{shamt(t)}: product of 5 linear terms in the shamt bits (degree 5) + // - val(t), rem(t), sign(t): multilinear (degree 1) + // ⇒ total degree ≤ 1 + 1 + max(1+5, 1, 1+5, 1) = 8 + degree_bound: 8, + } + } +} + +impl RoundOracle for Rv32PackedSraOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + let pow2_const: [K; 5] = [ + K::from_u64(2), + K::from_u64(4), + K::from_u64(16), + K::from_u64(256), + K::from_u64(65536), + ]; + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let val = self.val.singleton_value(); + let sign = self.sign.singleton_value(); + + let mut pow2 = K::ONE; + for (b, c) in self.shamt_bits.iter().zip(pow2_const.iter()) { + let bit = b.singleton_value(); + pow2 *= K::ONE + bit * (*c - K::ONE); + } + + let mut rem = K::ZERO; + for (i, b) in self.rem_bits.iter().enumerate() { + rem += b.singleton_value() * K::from_u64(1u64 << i); + } + + let corr = sign * two32 * (K::ONE - pow2); + let expr = lhs - val * pow2 - rem - corr; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let val0 = 
self.val.get(child0); + let val1 = self.val.get(child1); + let sign0 = self.sign.get(child0); + let sign1 = self.sign.get(child1); + + // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. + let mut b0s: [K; 5] = [K::ZERO; 5]; + let mut b1s: [K; 5] = [K::ZERO; 5]; + for (i, b) in self.shamt_bits.iter().enumerate() { + b0s[i] = b.get(child0); + b1s[i] = b.get(child1); + } + let mut r0s: [K; 31] = [K::ZERO; 31]; + let mut r1s: [K; 31] = [K::ZERO; 31]; + for (i, b) in self.rem_bits.iter().enumerate() { + r0s[i] = b.get(child0); + r1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let val_x = interp(val0, val1, x); + let sign_x = interp(sign0, sign1, x); + + let mut pow2_x = K::ONE; + for j in 0..5 { + let b_x = interp(b0s[j], b1s[j], x); + pow2_x *= K::ONE + b_x * (pow2_const[j] - K::ONE); + } + + let mut rem_x = K::ZERO; + for j in 0..31 { + let r_x = interp(r0s[j], r1s[j], x); + rem_x += r_x * K::from_u64(1u64 << j); + } + + let corr_x = sign_x * two32 * (K::ONE - pow2_x); + let expr_x = lhs_x - val_x * pow2_x - rem_x - corr_x; + if expr_x == K::ZERO { + continue; + } + + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + for b in self.shamt_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.sign.fold_round_in_place(r); + for b in 
self.rem_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle that enforces the SRA remainder is < 2^{shamt}: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · Σ_s eq(shamt(t), s) · Σ_{i≥s} 2^i · rem_bit_i(t) +/// +/// For SRA, we only carry 31 remainder bits (0..30). This is sufficient because `shamt ∈ [0,31]` +/// and the remainder range is always `< 2^{shamt} ≤ 2^{31}`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedSraAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, + degree_bound: usize, +} + +impl Rv32PackedSraAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(shamt_bits.len(), 5); + for (i, b) in shamt_bits.iter().enumerate() { + debug_assert_eq!( + b.len(), + 1usize << ell_n, + "shamt_bits[{i}] length must match time domain" + ); + } + debug_assert_eq!(rem_bits.len(), 31); + for (i, b) in rem_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + shamt_bits, + rem_bits, + // Degree bound: chi (1) + gate (1) + eq_s(shamt) (5) + tail(rem) (1) = 8. 
+ degree_bound: 8, + } + } +} + +impl RoundOracle for Rv32PackedSraAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + + let mut shamt: [K; 5] = [K::ZERO; 5]; + for (i, b) in self.shamt_bits.iter().enumerate() { + shamt[i] = b.singleton_value(); + } + + let mut rem: [K; 31] = [K::ZERO; 31]; + for (i, b) in self.rem_bits.iter().enumerate() { + rem[i] = b.singleton_value(); + } + + // tail_sum[s] = Σ_{i≥s} 2^i · rem_i, with tail_sum[31]=0 (no bits >= 31). + let mut tail_sum: [K; 32] = [K::ZERO; 32]; + let mut tail = K::ZERO; + for i in (0..31).rev() { + tail += rem[i] * K::from_u64(1u64 << i); + tail_sum[i] = tail; + } + tail_sum[31] = K::ZERO; + + let mut expr = K::ZERO; + for s in 0..32usize { + let mut prod = K::ONE; + for j in 0..5usize { + let b = shamt[j]; + if ((s >> j) & 1) == 1 { + prod *= b; + } else { + prod *= K::ONE - b; + } + } + expr += prod * tail_sum[s]; + } + + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. 
+ let mut b0s: [K; 5] = [K::ZERO; 5]; + let mut b1s: [K; 5] = [K::ZERO; 5]; + for (i, b) in self.shamt_bits.iter().enumerate() { + b0s[i] = b.get(child0); + b1s[i] = b.get(child1); + } + let mut r0s: [K; 31] = [K::ZERO; 31]; + let mut r1s: [K; 31] = [K::ZERO; 31]; + for (i, b) in self.rem_bits.iter().enumerate() { + r0s[i] = b.get(child0); + r1s[i] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let mut shamt: [K; 5] = [K::ZERO; 5]; + for j in 0..5usize { + shamt[j] = interp(b0s[j], b1s[j], x); + } + + let mut rem: [K; 31] = [K::ZERO; 31]; + for j in 0..31usize { + rem[j] = interp(r0s[j], r1s[j], x); + } + + // tail_sum[s] = Σ_{i≥s} 2^i · rem_i, with tail_sum[31]=0 (no bits >= 31). + let mut tail_sum: [K; 32] = [K::ZERO; 32]; + let mut tail = K::ZERO; + for j in (0..31).rev() { + tail += rem[j] * K::from_u64(1u64 << j); + tail_sum[j] = tail; + } + tail_sum[31] = K::ZERO; + + let mut expr_x = K::ZERO; + for s in 0..32usize { + let mut prod = K::ONE; + for j in 0..5usize { + let b = shamt[j]; + if ((s >> j) & 1) == 1 { + prod *= b; + } else { + prod *= K::ONE - b; + } + } + expr_x += prod * tail_sum[s]; + } + + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + for b in self.shamt_bits.iter_mut() { + b.fold_round_in_place(r); + } + for b in self.rem_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + 
+// ============================================================================ +// Packed-key RV32 DIV*/REM* Shout (time-domain) +// ============================================================================ + +/// Sparse Route A oracle for RV32 packed DIVU correctness: +/// Σ_t χ(t)·has_lookup(t)·( z(t)·(quot(t) - 0xFFFF_FFFF) + (1 - z(t))·(lhs(t) - rhs(t)·quot(t) - rem(t)) ) +/// +/// Where: +/// - `quot(t)` is the DIVU output (lane.val), +/// - `rem(t)` is an auxiliary witness column, +/// - `z(t)` is a witness bit intended to be 1 iff `rhs(t) == 0`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedDivuOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + rem: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + quot: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedDivuOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + rem: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + quot: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(rem.len(), 1usize << ell_n); + debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(quot.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + rem, + rhs_is_zero, + quot, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - rhs(t)·quot(t): degree 2 + // - gated by (1 - z(t)): adds 1 + // ⇒ total degree ≤ 1 + 1 + 2 + 1 = 5 + degree_bound: 5, + } + } +} + +impl RoundOracle for Rv32PackedDivuOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let all_ones = K::from_u64(u32::MAX as u64); + + if 
self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let rem = self.rem.singleton_value(); + let z = self.rhs_is_zero.singleton_value(); + let quot = self.quot.singleton_value(); + + let expr = z * (quot - all_ones) + (K::ONE - z) * (lhs - rhs * quot - rem); + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let rem0 = self.rem.get(child0); + let rem1 = self.rem.get(child1); + let z0 = self.rhs_is_zero.get(child0); + let z1 = self.rhs_is_zero.get(child1); + let quot0 = self.quot.get(child0); + let quot1 = self.quot.get(child1); + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = interp(rhs0, rhs1, x); + let rem_x = interp(rem0, rem1, x); + let z_x = interp(z0, z1, x); + let quot_x = interp(quot0, quot1, x); + + let expr_x = z_x * (quot_x - all_ones) + (K::ONE - z_x) * (lhs_x - rhs_x * quot_x - rem_x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + 
self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.rem.fold_round_in_place(r); + self.rhs_is_zero.fold_round_in_place(r); + self.quot.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed REMU correctness: +/// Σ_t χ(t)·has_lookup(t)·( z(t)·(rem(t) - lhs(t)) + (1 - z(t))·(lhs(t) - rhs(t)·quot(t) - rem(t)) ) +/// +/// Where: +/// - `rem(t)` is the REMU output (lane.val), +/// - `quot(t)` is an auxiliary witness column, +/// - `z(t)` is a witness bit intended to be 1 iff `rhs(t) == 0`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedRemuOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + quot: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + rem: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedRemuOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + quot: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + rem: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(quot.len(), 1usize << ell_n); + debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(rem.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + quot, + rhs_is_zero, + rem, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - 
rhs(t)·quot(t): degree 2 + // - gated by (1 - z(t)): adds 1 + // ⇒ total degree ≤ 1 + 1 + 2 + 1 = 5 + degree_bound: 5, + } + } +} + +impl RoundOracle for Rv32PackedRemuOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let quot = self.quot.singleton_value(); + let z = self.rhs_is_zero.singleton_value(); + let rem = self.rem.singleton_value(); + + let expr = z * (rem - lhs) + (K::ONE - z) * (lhs - rhs * quot - rem); + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let quot0 = self.quot.get(child0); + let quot1 = self.quot.get(child1); + let z0 = self.rhs_is_zero.get(child0); + let z1 = self.rhs_is_zero.get(child1); + let rem0 = self.rem.get(child0); + let rem1 = self.rem.get(child1); + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = interp(rhs0, rhs1, x); + let quot_x = interp(quot0, quot1, x); + let z_x = interp(z0, z1, x); + let rem_x = interp(rem0, rem1, x); + + let expr_x = z_x * (rem_x - lhs_x) 
+ (K::ONE - z_x) * (lhs_x - rhs_x * quot_x - rem_x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.quot.fold_round_in_place(r); + self.rhs_is_zero.fold_round_in_place(r); + self.rem.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed DIVU/REMU helpers (adapter): +/// Σ_t χ(t)·has_lookup(t)·Σ_i w_i · c_i(t) +/// +/// Constraints: +/// - `c0 = rhs_is_zero·(1 - rhs_is_zero)` (boolean helper; redundant with bitness) +/// - `c1 = rhs_is_zero·rhs` (rhs_is_zero => rhs==0) +/// - `c2 = (1 - rhs_is_zero)·(rem - rhs - diff + 2^32)` (remainder bound) +/// - `c3 = diff - Σ 2^i·diff_bit_i` (u32 decomposition) +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedDivRemuAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + rhs: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + rem: SparseIdxVec, + diff: SparseIdxVec, + diff_bits: Vec>, + weights: [K; 4], + degree_bound: usize, +} + +impl Rv32PackedDivRemuAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + rhs: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + rem: SparseIdxVec, + diff: SparseIdxVec, + diff_bits: Vec>, + weights: [K; 4], + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(rem.len(), 1usize << ell_n); + debug_assert_eq!(diff.len(), 1usize << ell_n); + debug_assert_eq!(diff_bits.len(), 32); + for (i, b) in diff_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + rhs, + rhs_is_zero, + rem, + diff, + diff_bits, + weights, + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - remainder bound term multiplies by (1 - rhs_is_zero): degree 2 + // ⇒ total degree ≤ 1 + 1 + 2 = 4 + degree_bound: 4, + } + } +} + +impl RoundOracle for Rv32PackedDivRemuAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + let w0 = self.weights[0]; + let w1 = self.weights[1]; + let w2 = self.weights[2]; + let w3 = self.weights[3]; + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let rhs = self.rhs.singleton_value(); + let z = self.rhs_is_zero.singleton_value(); + let rem = self.rem.singleton_value(); + let diff = self.diff.singleton_value(); + + let mut sum = K::ZERO; + for (i, b) in self.diff_bits.iter().enumerate() 
{ + sum += b.singleton_value() * K::from_u64(1u64 << i); + } + + let c0 = z * (K::ONE - z); + let c1 = z * rhs; + let c2 = (K::ONE - z) * (rem - rhs - diff + two32); + let c3 = diff - sum; + + let expr = w0 * c0 + w1 * c1 + w2 * c2 + w3 * c3; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let z0 = self.rhs_is_zero.get(child0); + let z1 = self.rhs_is_zero.get(child1); + let rem0 = self.rem.get(child0); + let rem1 = self.rem.get(child1); + let diff0 = self.diff.get(child0); + let diff1 = self.diff.get(child1); + + let mut b0s: [K; 32] = [K::ZERO; 32]; + let mut b1s: [K; 32] = [K::ZERO; 32]; + for (j, b) in self.diff_bits.iter().enumerate() { + b0s[j] = b.get(child0); + b1s[j] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let rhs_x = interp(rhs0, rhs1, x); + let z_x = interp(z0, z1, x); + let rem_x = interp(rem0, rem1, x); + let diff_x = interp(diff0, diff1, x); + + let mut sum = K::ZERO; + for j in 0..32 { + let b_x = interp(b0s[j], b1s[j], x); + sum += b_x * K::from_u64(1u64 << j); + } + + let c0 = z_x * (K::ONE - z_x); + let c1 = z_x * rhs_x; + let c2 = (K::ONE - z_x) * (rem_x - rhs_x - diff_x + two32); + let c3 = diff_x - sum; + let expr_x = w0 * c0 + w1 * c1 + w2 
* c2 + w3 * c3; + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.rhs_is_zero.fold_round_in_place(r); + self.rem.fold_round_in_place(r); + self.diff.fold_round_in_place(r); + for b in self.diff_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed DIV correctness (signed quotient output). +/// +/// Uses auxiliary `q_abs` (unsigned quotient), sign bits, and a `q_is_zero` witness bit to +/// handle the `-0` normalization case. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. +pub struct Rv32PackedDivOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + q_abs: SparseIdxVec, + q_is_zero: SparseIdxVec, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedDivOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + q_abs: SparseIdxVec, + q_is_zero: SparseIdxVec, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(q_abs.len(), 1usize << ell_n); + debug_assert_eq!(q_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: 
r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs_sign, + rhs_sign, + rhs_is_zero, + q_abs, + q_is_zero, + val, + // This oracle composes abs-quotient sign logic (degree 4) with rhs_is_zero gating. + degree_bound: 7, + } + } +} + +impl RoundOracle for Rv32PackedDivOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two = K::from_u64(2); + let two32 = K::from_u64(1u64 << 32); + let all_ones = K::from_u64(u32::MAX as u64); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let s1 = self.lhs_sign.singleton_value(); + let s2 = self.rhs_sign.singleton_value(); + let z = self.rhs_is_zero.singleton_value(); + let q_abs = self.q_abs.singleton_value(); + let q0 = self.q_is_zero.singleton_value(); + let val = self.val.singleton_value(); + + let div_sign = s1 + s2 - two * s1 * s2; + let neg_q = (K::ONE - q0) * (two32 - q_abs); + let q_signed = (K::ONE - div_sign) * q_abs + div_sign * neg_q; + + let expr = z * (val - all_ones) + (K::ONE - z) * (val - q_signed); + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let s10 = self.lhs_sign.get(child0); + let s11 = self.lhs_sign.get(child1); + let s20 = self.rhs_sign.get(child0); + let s21 = self.rhs_sign.get(child1); + let z0 = self.rhs_is_zero.get(child0); + let z1 = self.rhs_is_zero.get(child1); + let q0 = self.q_abs.get(child0); + let q1 = self.q_abs.get(child1); + let qz0 = self.q_is_zero.get(child0); + let qz1 = self.q_is_zero.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); 
+ + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let s1 = interp(s10, s11, x); + let s2 = interp(s20, s21, x); + let z = interp(z0, z1, x); + let q_abs = interp(q0, q1, x); + let qz = interp(qz0, qz1, x); + let val = interp(val0, val1, x); + + let div_sign = s1 + s2 - two * s1 * s2; + let neg_q = (K::ONE - qz) * (two32 - q_abs); + let q_signed = (K::ONE - div_sign) * q_abs + div_sign * neg_q; + + let expr_x = z * (val - all_ones) + (K::ONE - z) * (val - q_signed); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs_sign.fold_round_in_place(r); + self.rhs_sign.fold_round_in_place(r); + self.rhs_is_zero.fold_round_in_place(r); + self.q_abs.fold_round_in_place(r); + self.q_is_zero.fold_round_in_place(r); + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed REM correctness (signed remainder output). +/// +/// Uses auxiliary `r_abs` (unsigned remainder), the dividend sign bit, and `r_is_zero` to handle `-0`. +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct Rv32PackedRemOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + r_abs: SparseIdxVec, + r_is_zero: SparseIdxVec, + val: SparseIdxVec, + degree_bound: usize, +} + +impl Rv32PackedRemOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + r_abs: SparseIdxVec, + r_is_zero: SparseIdxVec, + val: SparseIdxVec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(r_abs.len(), 1usize << ell_n); + debug_assert_eq!(r_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + lhs_sign, + rhs_is_zero, + r_abs, + r_is_zero, + val, + degree_bound: 7, + } + } +} + +impl RoundOracle for Rv32PackedRemOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two32 = K::from_u64(1u64 << 32); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let s = self.lhs_sign.singleton_value(); + let z = self.rhs_is_zero.singleton_value(); + let r_abs = self.r_abs.singleton_value(); + let r0 = self.r_is_zero.singleton_value(); + let val = self.val.singleton_value(); + + let neg_r = (K::ONE - r0) * (two32 - r_abs); + let r_signed = (K::ONE - s) * r_abs + s * neg_r; + let expr = z * (val - lhs) + (K::ONE - z) * (val - r_signed); + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + 
debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let s0 = self.lhs_sign.get(child0); + let s1 = self.lhs_sign.get(child1); + let z0 = self.rhs_is_zero.get(child0); + let z1 = self.rhs_is_zero.get(child1); + let r0_0 = self.r_abs.get(child0); + let r0_1 = self.r_abs.get(child1); + let rz0 = self.r_is_zero.get(child0); + let rz1 = self.r_is_zero.get(child1); + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let s_x = interp(s0, s1, x); + let z_x = interp(z0, z1, x); + let r_abs_x = interp(r0_0, r0_1, x); + let rz_x = interp(rz0, rz1, x); + let val_x = interp(val0, val1, x); + + let neg_r = (K::ONE - rz_x) * (two32 - r_abs_x); + let r_signed = (K::ONE - s_x) * r_abs_x + s_x * neg_r; + let expr_x = z_x * (val_x - lhs_x) + (K::ONE - z_x) * (val_x - r_signed); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.lhs_sign.fold_round_in_place(r); + 
self.rhs_is_zero.fold_round_in_place(r); + self.r_abs.fold_round_in_place(r); + self.r_is_zero.fold_round_in_place(r); + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for RV32 packed DIV/REM helpers (adapter). +/// +/// This enforces: +/// - zero-detection for `rhs`, +/// - zero-detection for a signed-output magnitude (`q_abs` for DIV or `r_abs` for REM) to handle `-0`, +/// - absolute-value division equation over u32 values when `rhs != 0`, +/// - remainder bound `r_abs < |rhs|` when `rhs != 0`, +/// - u32 decomposition of the remainder-bound witness `diff`. +pub struct Rv32PackedDivRemAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + q_abs: SparseIdxVec, + r_abs: SparseIdxVec, + mag: SparseIdxVec, + mag_is_zero: SparseIdxVec, + diff: SparseIdxVec, + diff_bits: Vec>, + weights: [K; 7], + degree_bound: usize, +} + +impl Rv32PackedDivRemAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + rhs_is_zero: SparseIdxVec, + lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, + q_abs: SparseIdxVec, + r_abs: SparseIdxVec, + mag: SparseIdxVec, + mag_is_zero: SparseIdxVec, + diff: SparseIdxVec, + diff_bits: Vec>, + weights: [K; 7], + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); + debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); + debug_assert_eq!(q_abs.len(), 1usize << ell_n); + debug_assert_eq!(r_abs.len(), 1usize << ell_n); + debug_assert_eq!(mag.len(), 1usize << ell_n); + debug_assert_eq!(mag_is_zero.len(), 1usize << ell_n); + 
debug_assert_eq!(diff.len(), 1usize << ell_n); + debug_assert_eq!(diff_bits.len(), 32); + for (i, b) in diff_bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); + } + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs, + rhs, + rhs_is_zero, + lhs_sign, + rhs_sign, + q_abs, + r_abs, + mag, + mag_is_zero, + diff, + diff_bits, + weights, + degree_bound: 6, + } + } +} + +impl RoundOracle for Rv32PackedDivRemAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let two = K::from_u64(2); + let two32 = K::from_u64(1u64 << 32); + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + let z = self.rhs_is_zero.singleton_value(); + let lhs_sign = self.lhs_sign.singleton_value(); + let rhs_sign = self.rhs_sign.singleton_value(); + let q_abs = self.q_abs.singleton_value(); + let r_abs = self.r_abs.singleton_value(); + let mag = self.mag.singleton_value(); + let mag_z = self.mag_is_zero.singleton_value(); + let diff = self.diff.singleton_value(); + + let mut sum = K::ZERO; + for (i, b) in self.diff_bits.iter().enumerate() { + sum += b.singleton_value() * K::from_u64(1u64 << i); + } + + let lhs_abs = lhs + lhs_sign * (two32 - two * lhs); + let rhs_abs = rhs + rhs_sign * (two32 - two * rhs); + + let c0 = z * (K::ONE - z); + let c1 = z * rhs; + let c2 = mag_z * (K::ONE - mag_z); + let c3 = mag_z * mag; + let c4 = (K::ONE - z) * (lhs_abs - rhs_abs * q_abs - r_abs); + let c5 = (K::ONE - z) * (r_abs - rhs_abs - diff + two32); + let c6 = diff - sum; + + let w = &self.weights; + let expr = w[0] * c0 + w[1] * c1 + w[2] * c2 + w[3] * c3 + w[4] * c4 + w[5] * c5 + w[6] * c6; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 
2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + let z0 = self.rhs_is_zero.get(child0); + let z1 = self.rhs_is_zero.get(child1); + let lhs_sign0 = self.lhs_sign.get(child0); + let lhs_sign1 = self.lhs_sign.get(child1); + let rhs_sign0 = self.rhs_sign.get(child0); + let rhs_sign1 = self.rhs_sign.get(child1); + let q0 = self.q_abs.get(child0); + let q1 = self.q_abs.get(child1); + let r0 = self.r_abs.get(child0); + let r1 = self.r_abs.get(child1); + let mag0 = self.mag.get(child0); + let mag1 = self.mag.get(child1); + let mag_z0 = self.mag_is_zero.get(child0); + let mag_z1 = self.mag_is_zero.get(child1); + let diff0 = self.diff.get(child0); + let diff1 = self.diff.get(child1); + + let mut b0s: [K; 32] = [K::ZERO; 32]; + let mut b1s: [K; 32] = [K::ZERO; 32]; + for (j, b) in self.diff_bits.iter().enumerate() { + b0s[j] = b.get(child0); + b1s[j] = b.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs = interp(lhs0, lhs1, x); + let rhs = interp(rhs0, rhs1, x); + let z = interp(z0, z1, x); + let lhs_sign = interp(lhs_sign0, lhs_sign1, x); + let rhs_sign = interp(rhs_sign0, rhs_sign1, x); + let q_abs = interp(q0, q1, x); + let r_abs = interp(r0, r1, x); + let mag = interp(mag0, mag1, x); + let mag_z = interp(mag_z0, mag_z1, x); + let diff = interp(diff0, diff1, x); + + let mut sum = 
K::ZERO; + for j in 0..32 { + let b_x = interp(b0s[j], b1s[j], x); + sum += b_x * K::from_u64(1u64 << j); + } + + let lhs_abs = lhs + lhs_sign * (two32 - two * lhs); + let rhs_abs = rhs + rhs_sign * (two32 - two * rhs); + + let c0 = z * (K::ONE - z); + let c1 = z * rhs; + let c2 = mag_z * (K::ONE - mag_z); + let c3 = mag_z * mag; + let c4 = (K::ONE - z) * (lhs_abs - rhs_abs * q_abs - r_abs); + let c5 = (K::ONE - z) * (r_abs - rhs_abs - diff + two32); + let c6 = diff - sum; + + let w = &self.weights; + let expr_x = w[0] * c0 + w[1] * c1 + w[2] * c2 + w[3] * c3 + w[4] * c4 + w[5] * c5 + w[6] * c6; + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + self.rhs_is_zero.fold_round_in_place(r); + self.lhs_sign.fold_round_in_place(r); + self.rhs_sign.fold_round_in_place(r); + self.q_abs.fold_round_in_place(r); + self.r_abs.fold_round_in_place(r); + self.mag.fold_round_in_place(r); + self.mag_is_zero.fold_round_in_place(r); + self.diff.fold_round_in_place(r); + for b in self.diff_bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +// ============================================================================ +// Packed-key RV32 bitwise Shout (AND/OR/XOR) via 2-bit digits (time-domain) +// ============================================================================ + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Rv32PackedBitwiseOp2 { + And, + Andn, + Or, + Xor, +} + +#[inline] +fn rv32_two_bit_digit_bits(inv2: K, inv6: K, x: K) -> (K, K) { + // Bits for x in {0,1,2,3}, represented as degree-3 
polynomials over the field: + // - bit0: 0,1,0,1 + // - bit1: 0,0,1,1 + // + // Using Lagrange basis on {0,1,2,3}: + // L1(x) = x(x-2)(x-3)/2 + // L2(x) = -x(x-1)(x-3)/2 + // L3(x) = x(x-1)(x-2)/6 + // + // Then: + // bit0(x) = L1(x) + L3(x) + // bit1(x) = L2(x) + L3(x) + let xm1 = x - K::ONE; + let xm2 = x - K::from_u64(2); + let xm3 = x - K::from_u64(3); + + let x_xm1 = x * xm1; + let l1 = (x * xm2 * xm3) * inv2; + let l3 = (x_xm1 * xm2) * inv6; + let l2 = -(x_xm1 * xm3) * inv2; + + let bit0 = l1 + l3; + let bit1 = l2 + l3; + (bit0, bit1) +} + +#[inline] +fn rv32_two_bit_digit_op(inv2: K, inv6: K, op: Rv32PackedBitwiseOp2, a: K, b: K) -> K { + let (a0, a1) = rv32_two_bit_digit_bits(inv2, inv6, a); + let (b0, b1) = rv32_two_bit_digit_bits(inv2, inv6, b); + + let two = K::from_u64(2); + match op { + Rv32PackedBitwiseOp2::And => { + let r0 = a0 * b0; + let r1 = a1 * b1; + r0 + two * r1 + } + Rv32PackedBitwiseOp2::Andn => { + let r0 = a0 * (K::ONE - b0); + let r1 = a1 * (K::ONE - b1); + r0 + two * r1 + } + Rv32PackedBitwiseOp2::Or => { + let r0 = a0 + b0 - a0 * b0; + let r1 = a1 + b1 - a1 * b1; + r0 + two * r1 + } + Rv32PackedBitwiseOp2::Xor => { + let r0 = a0 + b0 - two * a0 * b0; + let r1 = a1 + b1 - two * a1 * b1; + r0 + two * r1 + } + } +} + +#[inline] +fn rv32_digit4_range_poly(x: K) -> K { + // Vanishes exactly on {0,1,2,3}. 
+ x * (x - K::ONE) * (x - K::from_u64(2)) * (x - K::from_u64(3)) +} + +pub struct Rv32PackedBitwiseAdapterOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + degree_bound: usize, + + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + weights: Vec, +} + +impl Rv32PackedBitwiseAdapterOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + weights: Vec, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(lhs.len(), 1usize << ell_n); + debug_assert_eq!(rhs.len(), 1usize << ell_n); + debug_assert_eq!(lhs_digits.len(), 16); + debug_assert_eq!(rhs_digits.len(), 16); + for (i, d) in lhs_digits.iter().enumerate() { + debug_assert_eq!( + d.len(), + 1usize << ell_n, + "lhs_digits[{i}] length must match time domain" + ); + } + for (i, d) in rhs_digits.iter().enumerate() { + debug_assert_eq!( + d.len(), + 1usize << ell_n, + "rhs_digits[{i}] length must match time domain" + ); + } + debug_assert_eq!(weights.len(), 34); + + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - digit4 range poly: degree 4 + // ⇒ total degree ≤ 1 + 1 + 4 = 6 + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + degree_bound: 6, + has_lookup, + lhs, + rhs, + lhs_digits, + rhs_digits, + weights, + } + } +} + +impl RoundOracle for Rv32PackedBitwiseAdapterOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + let w_lhs = self.weights[0]; + let w_rhs = self.weights[1]; + let w_digits = &self.weights[2..]; + + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let lhs = self.lhs.singleton_value(); + let rhs = self.rhs.singleton_value(); + + let mut lhs_recon = K::ZERO; + let mut rhs_recon = K::ZERO; + for i in 0..16usize { + lhs_recon += 
self.lhs_digits[i].singleton_value() * K::from_u64(1u64 << (2 * i)); + rhs_recon += self.rhs_digits[i].singleton_value() * K::from_u64(1u64 << (2 * i)); + } + + let mut range_sum = K::ZERO; + for (i, d) in self.lhs_digits.iter().enumerate() { + range_sum += w_digits[i] * rv32_digit4_range_poly(d.singleton_value()); + } + for (i, d) in self.rhs_digits.iter().enumerate() { + range_sum += w_digits[16 + i] * rv32_digit4_range_poly(d.singleton_value()); + } + + let expr = w_lhs * (lhs - lhs_recon) + w_rhs * (rhs - rhs_recon) + range_sum; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let lhs0 = self.lhs.get(child0); + let lhs1 = self.lhs.get(child1); + let rhs0 = self.rhs.get(child0); + let rhs1 = self.rhs.get(child1); + + let mut a0s: [K; 16] = [K::ZERO; 16]; + let mut a1s: [K; 16] = [K::ZERO; 16]; + for (i, d) in self.lhs_digits.iter().enumerate() { + a0s[i] = d.get(child0); + a1s[i] = d.get(child1); + } + let mut b0s: [K; 16] = [K::ZERO; 16]; + let mut b1s: [K; 16] = [K::ZERO; 16]; + for (i, d) in self.rhs_digits.iter().enumerate() { + b0s[i] = d.get(child0); + b1s[i] = d.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let lhs_x = interp(lhs0, lhs1, x); + let rhs_x = interp(rhs0, rhs1, x); + + let mut lhs_recon = K::ZERO; + let mut rhs_recon 
= K::ZERO; + let mut range_sum = K::ZERO; + for j in 0..16usize { + let aj = interp(a0s[j], a1s[j], x); + let bj = interp(b0s[j], b1s[j], x); + let pow = K::from_u64(1u64 << (2 * j)); + lhs_recon += aj * pow; + rhs_recon += bj * pow; + + range_sum += w_digits[j] * rv32_digit4_range_poly(aj); + range_sum += w_digits[16 + j] * rv32_digit4_range_poly(bj); + } + + let expr = w_lhs * (lhs_x - lhs_recon) + w_rhs * (rhs_x - rhs_recon) + range_sum; + ys[i] += chi_x * gate_x * expr; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.lhs.fold_round_in_place(r); + self.rhs.fold_round_in_place(r); + for d in self.lhs_digits.iter_mut() { + d.fold_round_in_place(r); + } + for d in self.rhs_digits.iter_mut() { + d.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +pub struct Rv32PackedBitwiseOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + val: SparseIdxVec, + op: Rv32PackedBitwiseOp2, + inv2: K, + inv6: K, + degree_bound: usize, +} + +impl Rv32PackedBitwiseOracleSparseTime { + fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + val: SparseIdxVec, + op: Rv32PackedBitwiseOp2, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(val.len(), 1usize << ell_n); + debug_assert_eq!(lhs_digits.len(), 16); + debug_assert_eq!(rhs_digits.len(), 16); + for (i, d) in lhs_digits.iter().enumerate() { + debug_assert_eq!( + d.len(), + 1usize << ell_n, + "lhs_digits[{i}] length must match time domain" + ); + } + for (i, d) in rhs_digits.iter().enumerate() { + debug_assert_eq!( + d.len(), + 1usize << 
ell_n, + "rhs_digits[{i}] length must match time domain" + ); + } + + // Degree bound: + // - chi(t): multilinear (degree 1) + // - has_lookup(t): multilinear (degree 1) + // - bitwise digit op: degree 6 (two degree-3 bit extractors multiplied) + // ⇒ total degree ≤ 1 + 1 + 6 = 8 + let inv2 = K::from_u64(2).inverse(); + let inv6 = K::from_u64(6).inverse(); + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + lhs_digits, + rhs_digits, + val, + op, + inv2, + inv6, + degree_bound: 8, + } + } +} + +impl RoundOracle for Rv32PackedBitwiseOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let val = self.val.singleton_value(); + + let mut out = K::ZERO; + for i in 0..16usize { + let a = self.lhs_digits[i].singleton_value(); + let b = self.rhs_digits[i].singleton_value(); + let digit = rv32_two_bit_digit_op(self.inv2, self.inv6, self.op, a, b); + out += digit * K::from_u64(1u64 << (2 * i)); + } + let expr = out - val; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let val0 = self.val.get(child0); + let val1 = self.val.get(child1); + + let mut a0s: [K; 16] = [K::ZERO; 16]; + let mut a1s: [K; 16] = [K::ZERO; 16]; + for (j, d) in self.lhs_digits.iter().enumerate() { + a0s[j] = d.get(child0); + a1s[j] = d.get(child1); + } + let mut b0s: [K; 16] = [K::ZERO; 16]; + let mut b1s: [K; 16] = [K::ZERO; 16]; + for (j, d) in self.rhs_digits.iter().enumerate() { + b0s[j] = d.get(child0); + 
b1s[j] = d.get(child1); + } + + let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + + let val_x = interp(val0, val1, x); + + let mut out = K::ZERO; + for j in 0..16usize { + let a = interp(a0s[j], a1s[j], x); + let b = interp(b0s[j], b1s[j], x); + let digit = rv32_two_bit_digit_op(self.inv2, self.inv6, self.op, a, b); + out += digit * K::from_u64(1u64 << (2 * j)); + } + + let expr = out - val_x; + ys[i] += chi_x * gate_x * expr; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + for d in self.lhs_digits.iter_mut() { + d.fold_round_in_place(r); + } + for d in self.rhs_digits.iter_mut() { + d.fold_round_in_place(r); + } + self.val.fold_round_in_place(r); + self.bit_idx += 1; + } +} + +pub struct Rv32PackedAndOracleSparseTime { + core: Rv32PackedBitwiseOracleSparseTime, +} +impl Rv32PackedAndOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + val: SparseIdxVec, + ) -> Self { + Self { + core: Rv32PackedBitwiseOracleSparseTime::new( + r_cycle, + has_lookup, + lhs_digits, + rhs_digits, + val, + Rv32PackedBitwiseOp2::And, + ), + } + } +} +impl_round_oracle_via_core!(Rv32PackedAndOracleSparseTime); + +pub struct Rv32PackedAndnOracleSparseTime { + core: Rv32PackedBitwiseOracleSparseTime, +} +impl Rv32PackedAndnOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + val: SparseIdxVec, + ) -> Self { + Self { + 
core: Rv32PackedBitwiseOracleSparseTime::new( + r_cycle, + has_lookup, + lhs_digits, + rhs_digits, + val, + Rv32PackedBitwiseOp2::Andn, + ), + } + } +} +impl_round_oracle_via_core!(Rv32PackedAndnOracleSparseTime); + +pub struct Rv32PackedOrOracleSparseTime { + core: Rv32PackedBitwiseOracleSparseTime, +} +impl Rv32PackedOrOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + val: SparseIdxVec, + ) -> Self { + Self { + core: Rv32PackedBitwiseOracleSparseTime::new( + r_cycle, + has_lookup, + lhs_digits, + rhs_digits, + val, + Rv32PackedBitwiseOp2::Or, + ), + } + } +} +impl_round_oracle_via_core!(Rv32PackedOrOracleSparseTime); + +pub struct Rv32PackedXorOracleSparseTime { + core: Rv32PackedBitwiseOracleSparseTime, +} +impl Rv32PackedXorOracleSparseTime { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + val: SparseIdxVec, + ) -> Self { + Self { + core: Rv32PackedBitwiseOracleSparseTime::new( + r_cycle, + has_lookup, + lhs_digits, + rhs_digits, + val, + Rv32PackedBitwiseOp2::Xor, + ), + } + } +} +impl_round_oracle_via_core!(Rv32PackedXorOracleSparseTime); + +/// Sparse Route A oracle for u32 bit-decomposition consistency: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (x(t) - Σ_i 2^i · bit_i(t)) +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct U32DecompOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + x: SparseIdxVec, + bits: Vec>, + weights: Vec, + degree_bound: usize, +} + +impl U32DecompOracleSparseTime { + pub fn new(r_cycle: &[K], has_lookup: SparseIdxVec, x: SparseIdxVec, bits: Vec>) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(x.len(), 1usize << ell_n); + debug_assert_eq!(bits.len(), 32); + for (i, b) in bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "bits[{i}] length must match time domain"); + } + let weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + x, + bits, + weights, + degree_bound: 3, + } + } +} + +impl RoundOracle for U32DecompOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let x = self.x.singleton_value(); + let mut sum = K::ZERO; + for (b, w) in self.bits.iter().zip(self.weights.iter()) { + sum += b.singleton_value() * *w; + } + let expr = x - sum; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let x0 = self.x.get(child0); + let x1 = self.x.get(child1); + + let mut expr0 = x0; + let mut expr1 = x1; + for (b_col, w) in self.bits.iter().zip(self.weights.iter()) { + let b0 = b_col.get(child0); + let b1 = b_col.get(child1); + expr0 -= b0 * *w; + expr1 -= b1 * *w; + } + + let 
(chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.x.fold_round_in_place(r); + for b in self.bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +/// Sparse Route A oracle for u5 bit-decomposition consistency: +/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (x(t) - Σ_i 2^i · bit_i(t)) +/// +/// Intended usage: set the claimed sum to 0 to enforce correctness. 
+pub struct U5DecompOracleSparseTime { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + x: SparseIdxVec, + bits: Vec>, + weights: Vec, + degree_bound: usize, +} + +impl U5DecompOracleSparseTime { + pub fn new(r_cycle: &[K], has_lookup: SparseIdxVec, x: SparseIdxVec, bits: Vec>) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + debug_assert_eq!(x.len(), 1usize << ell_n); + debug_assert_eq!(bits.len(), 5); + for (i, b) in bits.iter().enumerate() { + debug_assert_eq!(b.len(), 1usize << ell_n, "bits[{i}] length must match time domain"); + } + let weights: Vec = (0..5).map(|i| K::from_u64(1u64 << i)).collect(); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + x, + bits, + weights, + degree_bound: 3, + } + } +} + +impl RoundOracle for U5DecompOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + if self.has_lookup.len() == 1 { + let gate = self.has_lookup.singleton_value(); + let x = self.x.singleton_value(); + let mut sum = K::ZERO; + for (b, w) in self.bits.iter().zip(self.weights.iter()) { + sum += b.singleton_value() * *w; + } + let expr = x - sum; + let v = self.prefix_eq * gate * expr; + return vec![v; points.len()]; + } + + let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); + let half = self.has_lookup.len() / 2; + debug_assert!(pairs.iter().all(|&p| p < half)); + + let mut ys = vec![K::ZERO; points.len()]; + for &pair in pairs.iter() { + let child0 = 2 * pair; + let child1 = child0 + 1; + + let gate0 = self.has_lookup.get(child0); + let gate1 = self.has_lookup.get(child1); + if gate0 == K::ZERO && gate1 == K::ZERO { + continue; + } + + let x0 = self.x.get(child0); + let x1 = self.x.get(child1); + + let mut expr0 = x0; + let mut expr1 = x1; + for (b_col, w) in self.bits.iter().zip(self.weights.iter()) { + let b0 = b_col.get(child0); + let b1 = b_col.get(child1); + expr0 -= b0 * *w; + expr1 -= b1 * *w; + } + + let 
(chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); + + for (i, &x) in points.iter().enumerate() { + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let expr_x = interp(expr0, expr1, x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; + } + } + ys + } + + fn num_rounds(&self) -> usize { + self.r_cycle.len().saturating_sub(self.bit_idx) + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, r: K) { + if self.num_rounds() == 0 { + return; + } + self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); + self.has_lookup.fold_round_in_place(r); + self.x.fold_round_in_place(r); + for b in self.bits.iter_mut() { + b.fold_round_in_place(r); + } + self.bit_idx += 1; + } +} + +/// Zero oracle over the time hypercube (for placeholder claims). +pub struct ZeroOracleSparseTime { + rounds_remaining: usize, + degree_bound: usize, +} + +impl ZeroOracleSparseTime { + pub fn new(num_rounds: usize, degree_bound: usize) -> Self { + Self { + rounds_remaining: num_rounds, + degree_bound, + } + } +} + +impl RoundOracle for ZeroOracleSparseTime { + fn evals_at(&mut self, points: &[K]) -> Vec { + vec![K::ZERO; points.len()] + } + + fn num_rounds(&self) -> usize { + self.rounds_remaining + } + + fn degree_bound(&self) -> usize { + self.degree_bound + } + + fn fold(&mut self, _r: K) { + if self.rounds_remaining > 0 { + self.rounds_remaining -= 1; + } + } +} + #[inline] fn interp(f0: K, f1: K, x: K) -> K { f0 + (f1 - f0) * x diff --git a/crates/neo-memory/src/witness.rs b/crates/neo-memory/src/witness.rs index 7fa61ba0..d0b416b1 100644 --- a/crates/neo-memory/src/witness.rs +++ b/crates/neo-memory/src/witness.rs @@ -24,6 +24,51 @@ pub enum LutTableSpec { /// - Address bits are little-endian and correspond to `interleave_bits(rs1, rs2)`. 
RiscvOpcode { opcode: RiscvOpcode, xlen: usize }, + /// A packed-key (non-bit-addressed) variant of `RiscvOpcode`, intended for "no width bloat" + /// Shout/ALU proofs that do **not** commit to `ell_addr=64` addr-bit columns. + /// + /// Current implementation status: + /// - Supported: `xlen = 32` for selected RV32 ops, including: + /// - bitwise: `And | Andn | Or | Xor` + /// - arithmetic: `Add | Sub` + /// - compares: `Eq | Neq | Slt | Sltu` + /// - shifts: `Sll | Srl | Sra` + /// - RV32M: `Mul | Mulh | Mulhu | Mulhsu | Div | Divu | Rem | Remu` + /// - Witness convention: the Shout lane's `addr_bits` slice is repurposed as packed columns. + /// The exact layout depends on `opcode`; the suffix columns are always `[has_lookup, val_u32]`. + /// Examples: + /// - `Add/Sub` (d=3): `[lhs_u32, rhs_u32, aux_bit]` (carry for `Add`, borrow for `Sub`) + /// - `Eq/Neq` (d=35): `[lhs_u32, rhs_u32, borrow_bit, diff_bits[0..32]]` where `val_u32` is the out bit + /// - `Mul` (d=34): `[lhs_u32, rhs_u32, hi_bits[0..32]]` where `val_u32` is the low 32 bits + /// - `Mulhu` (d=34): `[lhs_u32, rhs_u32, lo_bits[0..32]]` where `val_u32` is the high 32 bits + /// - `Sltu` (d=35): `[lhs_u32, rhs_u32, diff_u32, diff_bits[0..32]]` where `val_u32` is the out bit + /// - `Sll/Srl/Sra` (d=38): `[lhs_u32, shamt_bits[0..5], ...]` + /// + /// For packed-key instances, Route-A enforces correctness directly via time-domain constraints + /// (claimed sum forced to 0); table MLE evaluation is not used. + RiscvOpcodePacked { opcode: RiscvOpcode, xlen: usize }, + + /// An "event table" packed-key variant of `RiscvOpcodePacked` for RV32. + /// + /// Instead of storing one Shout lane over time (one row per cycle), the witness stores only + /// the executed lookup events (one row per lookup). Each event row carries: + /// - a prefix of `time_bits` boolean columns encoding the original Route-A time index `t` + /// (little-endian), and + /// - the same packed columns as `RiscvOpcodePacked` for `opcode`. 
+ /// + /// The Route-A protocol then links the event table back to the CPU trace via a "scatter" + /// check at a random time point `r_cycle` (Jolt-ish): roughly, + /// Σ_events hash(event)·χ_{r_cycle}(t_event) == Σ_t hash(trace[t])·χ_{r_cycle}(t). + /// + /// Notes: + /// - This is RV32-only (`xlen = 32`), `n_side = 2`, `ell = 1`. + /// - `time_bits` must match the Route-A `ell_n` used for the time domain. + RiscvOpcodeEventTablePacked { + opcode: RiscvOpcode, + xlen: usize, + time_bits: usize, + }, + /// Implicit identity table over 32-bit addresses: `table[addr] = addr`. /// /// Addressing convention: @@ -39,6 +84,12 @@ impl LutTableSpec { LutTableSpec::RiscvOpcode { opcode, xlen } => { Ok(crate::riscv::lookups::evaluate_opcode_mle(*opcode, r_addr, *xlen)) } + LutTableSpec::RiscvOpcodePacked { .. } => Err(PiCcsError::InvalidInput( + "RiscvOpcodePacked does not support eval_table_mle (not bit-addressed)".into(), + )), + LutTableSpec::RiscvOpcodeEventTablePacked { .. } => Err(PiCcsError::InvalidInput( + "RiscvOpcodeEventTablePacked does not support eval_table_mle (not bit-addressed)".into(), + )), LutTableSpec::IdentityU32 => { if r_addr.len() != 32 { return Err(PiCcsError::InvalidInput(format!( @@ -54,6 +105,11 @@ impl LutTableSpec { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct MemInstance { + /// Logical memory instance identifier (e.g. RISC-V `PROG_ID/REG_ID/RAM_ID`). + /// + /// This is used by higher-level protocols to link Twist instances to CPU trace columns + /// without relying on a fixed instance ordering. + pub mem_id: u32, pub comms: Vec, pub k: usize, pub d: usize, @@ -129,6 +185,8 @@ impl MemInstance { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct LutInstance { + /// Logical shout table identifier. 
+ pub table_id: u32, pub comms: Vec, pub k: usize, pub d: usize, @@ -146,6 +204,12 @@ pub struct LutInstance { #[serde(default)] pub table_spec: Option, pub table: Vec, + /// Optional address-sharing group id for shared-bus column layout. + #[serde(default)] + pub addr_group: Option, + /// Optional selector-sharing group id for shared-bus column layout. + #[serde(default)] + pub selector_group: Option, } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/crates/neo-memory/tests/cpu_bus_multi_instance_injection.rs b/crates/neo-memory/tests/cpu_bus_multi_instance_injection.rs index b4bc70bf..b9822117 100644 --- a/crates/neo-memory/tests/cpu_bus_multi_instance_injection.rs +++ b/crates/neo-memory/tests/cpu_bus_multi_instance_injection.rs @@ -34,8 +34,9 @@ fn empty_identity_first_r1cs_ccs(n: usize) -> CcsStructure { CcsStructure::new(vec![i_n, a, b, c], f).expect("CCS") } -fn lut_inst() -> LutInstance<(), F> { +fn lut_inst(table_id: u32) -> LutInstance<(), F> { LutInstance { + table_id, comms: Vec::new(), k: 2, d: 1, @@ -45,11 +46,14 @@ fn lut_inst() -> LutInstance<(), F> { ell: 1, table_spec: None, table: vec![F::ZERO, F::ONE], + addr_group: None, + selector_group: None, } } -fn mem_inst() -> MemInstance<(), F> { +fn mem_inst(mem_id: u32) -> MemInstance<(), F> { MemInstance { + mem_id, comms: Vec::new(), k: 2, d: 1, @@ -67,8 +71,10 @@ fn shared_cpu_bus_injection_supports_independent_instances() { let n = 64usize; let base_ccs = empty_identity_first_r1cs_ccs(n); - let lut_insts = vec![lut_inst(), lut_inst()]; - let mem_insts = vec![mem_inst(), mem_inst()]; + // Use table IDs outside RV32 shared-address groups so each instance has an independent + // `[addr_bits, has_lookup, val]` bus slice in this regression. + let lut_insts = vec![lut_inst(1000), lut_inst(1001)]; + let mem_insts = vec![mem_inst(100), mem_inst(101)]; // CPU columns (all < bus_base) are per-instance. 
let shout_cpu = vec![ @@ -124,11 +130,11 @@ fn shared_cpu_bus_injection_supports_independent_instances() { // CPU witness: make only shout0 and twist1 active. z[1] = F::ONE; // shout0.has_lookup z[2] = F::ONE; // shout0.addr (packed) - z[3] = F::from_u64(7); // shout0.val + z[3] = F::from_u64(7); // shout0.primary_val() z[4] = F::ZERO; // shout1.has_lookup z[5] = F::ZERO; // shout1.addr - z[6] = F::ZERO; // shout1.val + z[6] = F::ZERO; // shout1.primary_val() z[7] = F::ZERO; // twist0.has_read z[8] = F::ZERO; // twist0.has_write diff --git a/crates/neo-memory/tests/cpu_constraints_tests.rs b/crates/neo-memory/tests/cpu_constraints_tests.rs index e2919e43..82ba0740 100644 --- a/crates/neo-memory/tests/cpu_constraints_tests.rs +++ b/crates/neo-memory/tests/cpu_constraints_tests.rs @@ -38,7 +38,7 @@ fn test_shout_bus_config() { let cfg = &bus.shout_cols[0].lanes[0]; assert_eq!(cfg.addr_bits, 0..4); assert_eq!(cfg.has_lookup, 4); - assert_eq!(cfg.val, 5); + assert_eq!(cfg.primary_val(), 5); assert_eq!(bus.bus_cols, 6); } diff --git a/crates/neo-memory/tests/r1cs_cpu_shared_bus_no_footguns.rs b/crates/neo-memory/tests/r1cs_cpu_shared_bus_no_footguns.rs index 2cddc584..6f28d230 100644 --- a/crates/neo-memory/tests/r1cs_cpu_shared_bus_no_footguns.rs +++ b/crates/neo-memory/tests/r1cs_cpu_shared_bus_no_footguns.rs @@ -107,7 +107,8 @@ fn with_shared_cpu_bus_injects_constraints_and_forces_const_one() { &tables, &HashMap::new(), Box::new(|_step| vec![F::ZERO]), - ); + ) + .expect("R1csCpu::new"); let mut mem_layouts: HashMap = HashMap::new(); mem_layouts.insert( @@ -144,6 +145,8 @@ fn with_shared_cpu_bus_injects_constraints_and_forces_const_one() { inc: None, }], )]), + shout_addr_groups: HashMap::new(), + shout_selector_groups: HashMap::new(), }; let cpu = cpu @@ -171,6 +174,85 @@ fn with_shared_cpu_bus_injects_constraints_and_forces_const_one() { check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("satisfiable"); } +#[test] +fn 
with_shared_cpu_bus_accepts_empty_shout_bindings_for_padding_only_mode() { + let n = 64usize; + let ccs = empty_identity_first_r1cs_ccs(n); + let params = NeoParams::goldilocks_auto_r1cs_ccs(n).expect("params"); + + let mut tables: HashMap> = HashMap::new(); + tables.insert( + 1, + LutTable { + table_id: 1, + k: 2, + d: 1, + n_side: 2, + content: vec![F::ZERO, F::ONE], + }, + ); + + let cpu = R1csCpu::new( + ccs.clone(), + params, + NoopCommit::default(), + /*m_in=*/ 1, + &tables, + &HashMap::new(), + Box::new(|_step| vec![F::ZERO]), + ) + .expect("R1csCpu::new"); + + let cfg = SharedCpuBusConfig:: { + mem_layouts: HashMap::new(), + initial_mem: HashMap::new(), + const_one_col: 0, + // Empty binding vector => padding/bitness-only shout lane constraints. + shout_cpu: HashMap::from([(1, Vec::::new())]), + twist_cpu: HashMap::new(), + shout_addr_groups: HashMap::new(), + shout_selector_groups: HashMap::new(), + }; + + let cpu = cpu + .with_shared_cpu_bus(cfg, /*chunk_size=*/ 1) + .expect("enable shared_cpu_bus with empty shout bindings"); + + assert!( + ccs_matrix_has_any_nonzero(&cpu.ccs.matrices[1]), + "expected injected shout padding/bitness constraints in A matrix" + ); + assert!( + ccs_matrix_has_any_nonzero(&cpu.ccs.matrices[2]), + "expected injected shout padding/bitness constraints in B matrix" + ); + + let trace = VmTrace { + steps: vec![StepTrace { + cycle: 0, + pc_before: 0, + pc_after: 0, + opcode: 0, + regs_before: Vec::new(), + regs_after: Vec::new(), + twist_events: Vec::new(), + shout_events: vec![ShoutEvent { + shout_id: ShoutId(1), + key: 1, + value: 7, + }], + halted: false, + }], + }; + + let mcss = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build ccs steps"); + assert_eq!(mcss.len(), 1); + let (mcs_inst, mcs_wit) = &mcss[0]; + assert_eq!(mcs_inst.x.len(), 1); + assert_eq!(mcs_inst.x[0], F::ONE, "const_one_col should be forced to 1"); + check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("satisfiable"); +} + #[test] fn 
shared_bus_shout_lane_assignment_is_in_order_and_resets_per_step() { let n = 128usize; @@ -238,7 +320,8 @@ fn shared_bus_shout_lane_assignment_is_in_order_and_resets_per_step() { } z }), - ); + ) + .expect("R1csCpu::new"); let cfg = SharedCpuBusConfig:: { mem_layouts: HashMap::new(), @@ -246,6 +329,8 @@ fn shared_bus_shout_lane_assignment_is_in_order_and_resets_per_step() { const_one_col: 0, shout_cpu: HashMap::from([(1, vec![shout_lane0_cfg, shout_lane1_cfg])]), twist_cpu: HashMap::new(), + shout_addr_groups: HashMap::new(), + shout_selector_groups: HashMap::new(), }; let cpu = cpu @@ -325,20 +410,20 @@ fn shared_bus_shout_lane_assignment_is_in_order_and_resets_per_step() { // Step j=0: lane0=(key=0,val=5), lane1=(key=1,val=7). assert_eq!(z[layout.bus_cell(lane0.has_lookup, 0)], F::ONE); assert_eq!(z[layout.bus_cell(lane0.addr_bits.start, 0)], F::ZERO); - assert_eq!(z[layout.bus_cell(lane0.val, 0)], F::from_u64(5)); + assert_eq!(z[layout.bus_cell(lane0.primary_val(), 0)], F::from_u64(5)); assert_eq!(z[layout.bus_cell(lane1.has_lookup, 0)], F::ONE); assert_eq!(z[layout.bus_cell(lane1.addr_bits.start, 0)], F::ONE); - assert_eq!(z[layout.bus_cell(lane1.val, 0)], F::from_u64(7)); + assert_eq!(z[layout.bus_cell(lane1.primary_val(), 0)], F::from_u64(7)); // Step j=1: lane0=(key=1,val=9), lane1=(key=0,val=11). assert_eq!(z[layout.bus_cell(lane0.has_lookup, 1)], F::ONE); assert_eq!(z[layout.bus_cell(lane0.addr_bits.start, 1)], F::ONE); - assert_eq!(z[layout.bus_cell(lane0.val, 1)], F::from_u64(9)); + assert_eq!(z[layout.bus_cell(lane0.primary_val(), 1)], F::from_u64(9)); assert_eq!(z[layout.bus_cell(lane1.has_lookup, 1)], F::ONE); assert_eq!(z[layout.bus_cell(lane1.addr_bits.start, 1)], F::ZERO); - assert_eq!(z[layout.bus_cell(lane1.val, 1)], F::from_u64(11)); + assert_eq!(z[layout.bus_cell(lane1.primary_val(), 1)], F::from_u64(11)); // The injected constraints should be satisfiable for the constructed witness. 
check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("satisfiable"); @@ -370,7 +455,8 @@ fn shared_bus_rejects_shout_lane_overflow_in_one_step() { &tables, &HashMap::new(), Box::new(|_chunk| vec![F::ZERO]), - ); + ) + .expect("R1csCpu::new"); let cfg = SharedCpuBusConfig:: { mem_layouts: HashMap::new(), @@ -385,6 +471,8 @@ fn shared_bus_rejects_shout_lane_overflow_in_one_step() { }], )]), twist_cpu: HashMap::new(), + shout_addr_groups: HashMap::new(), + shout_selector_groups: HashMap::new(), }; let cpu = cpu @@ -435,7 +523,8 @@ fn with_shared_cpu_bus_rejects_non_public_const_one() { &tables, &HashMap::new(), Box::new(|_step| vec![F::ZERO]), - ); + ) + .expect("R1csCpu::new"); let cfg = SharedCpuBusConfig:: { mem_layouts: HashMap::new(), @@ -443,6 +532,8 @@ fn with_shared_cpu_bus_rejects_non_public_const_one() { const_one_col: 1, // not < m_in shout_cpu: HashMap::new(), twist_cpu: HashMap::new(), + shout_addr_groups: HashMap::new(), + shout_selector_groups: HashMap::new(), }; assert!( @@ -477,7 +568,8 @@ fn with_shared_cpu_bus_rejects_bindings_in_bus_tail() { &tables, &HashMap::new(), Box::new(|_step| vec![F::ZERO]), - ); + ) + .expect("R1csCpu::new"); let mut mem_layouts: HashMap = HashMap::new(); mem_layouts.insert( @@ -515,6 +607,8 @@ fn with_shared_cpu_bus_rejects_bindings_in_bus_tail() { inc: None, }], )]), + shout_addr_groups: HashMap::new(), + shout_selector_groups: HashMap::new(), }; assert!( diff --git a/crates/neo-memory/tests/riscv_ccs_tests.rs b/crates/neo-memory/tests/riscv_ccs_tests.rs index 43e11ded..18dd7003 100644 --- a/crates/neo-memory/tests/riscv_ccs_tests.rs +++ b/crates/neo-memory/tests/riscv_ccs_tests.rs @@ -1,4648 +1,293 @@ -//! Tests for the RV32 B1 shared-bus step CCS. 
- use std::collections::HashMap; -use neo_ccs::matrix::Mat; use neo_ccs::relations::check_ccs_rowwise_zero; -use neo_ccs::traits::SModuleHomomorphism; +use neo_memory::cpu::CPU_BUS_COL_DISABLED; use neo_memory::plain::PlainMemLayout; use neo_memory::riscv::ccs::{ - build_rv32_b1_step_ccs, rv32_b1_chunk_to_full_witness_checked, rv32_b1_chunk_to_witness, - rv32_b1_shared_cpu_bus_config, + build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, rv32_trace_shared_bus_requirements_with_specs, + rv32_trace_shared_cpu_bus_config_with_specs, Rv32TraceCcsLayout, TraceShoutBusSpec, }; +use neo_memory::riscv::exec_table::Rv32ExecTable; use neo_memory::riscv::lookups::{ - decode_instruction, encode_program, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, - RiscvOpcode, RiscvShoutTables, JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID, + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, RiscvOpcode, RiscvShoutTables, + PROG_ID, RAM_ID, REG_ID, }; -use neo_memory::riscv::rom_init::prog_init_words; -use neo_memory::witness::LutTableSpec; -use neo_memory::{CpuArithmetization, R1csCpu}; -use neo_params::NeoParams; -use neo_vm_trace::{trace_program, StepTrace, TwistEvent, TwistOpKind, VmTrace}; +use neo_memory::riscv::trace::{rv32_decode_lookup_table_id_for_col, Rv32DecodeSidecarLayout}; +use neo_vm_trace::trace_program; use p3_field::PrimeCharacteristicRing; use p3_goldilocks::Goldilocks as F; -#[derive(Clone, Copy, Default)] -struct NoopCommit; - -impl SModuleHomomorphism for NoopCommit { - fn commit(&self, _z: &Mat) -> () {} - - fn project_x(&self, z: &Mat, m_in: usize) -> Mat { - let rows = z.rows(); - let mut out = Mat::zero(rows, m_in, F::ZERO); - for r in 0..rows { - for c in 0..m_in.min(z.cols()) { - out[(r, c)] = z[(r, c)]; - } - } - out - } -} - -fn pow2_ceil_k(min_k: usize) -> (usize, usize) { - let k = min_k.next_power_of_two().max(2); - let d = k.trailing_zeros() as usize; - (k, d) 
-} - -fn load_u32_imm(rd: u8, value: u32) -> Vec { - let upper = ((value as i64 + 0x800) >> 12) as i32; - let lower = (value as i32) - (upper << 12); - vec![ - RiscvInstruction::Lui { rd, imm: upper }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd, - rs1: rd, - imm: lower, - }, - ] -} - -const RV32I_SHOUT_TABLE_IDS: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; - -fn rv32i_table_specs(xlen: usize) -> HashMap { +fn sample_mem_layouts() -> HashMap { HashMap::from([ ( - 0u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::And, - xlen, - }, - ), - ( - 1u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Xor, - xlen, - }, - ), - ( - 2u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Or, - xlen, - }, - ), - ( - 3u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Add, - xlen, - }, - ), - ( - 4u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Sub, - xlen, - }, - ), - ( - 5u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Slt, - xlen, - }, - ), - ( - 6u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Sltu, - xlen, - }, - ), - ( - 7u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Sll, - xlen, - }, - ), - ( - 8u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Srl, - xlen, - }, - ), - ( - 9u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Sra, - xlen, - }, - ), - ( - 10u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Eq, - xlen, - }, - ), - ( - 11u32, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Neq, - xlen, - }, - ), - ]) -} - -#[test] -fn rv32_b1_ccs_happy_path_small_program() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::Lui { rd: 1, imm: 1 }, // x1 = 0x1000 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 1, - imm: 5, - }, // x1 = 0x1005 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 7, - }, // x2 = 7 - RiscvInstruction::RAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 1, - rs2: 2, - }, // x3 = 0x100c 
- RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 0, - rs2: 3, - imm: 0x100, - }, // mem[0x100] = x3 - RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 4, - rs1: 0, - imm: 0x100, - }, // x4 = mem[0x100] - RiscvInstruction::Auipc { rd: 5, imm: 0 }, // x5 = pc - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - // mem_layouts: keep k small to reduce bus tail width. - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); // covers addresses up to 0x1ff - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, + PROG_ID.0, PlainMemLayout { - k: k_prog, - d: d_prog, + k: 16, + d: 4, n_side: 2, lanes: 1, }, ), - ]); - - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - 
-#[test] -fn rv32_b1_ccs_happy_path_rv32i_fence_program() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::Fence { pred: 0, succ: 0 }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 1, - imm: 2, - }, // x2 = 3 - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let regs = &trace.steps.last().expect("steps").regs_after; - assert_eq!(regs[1], 1, "ADD before FENCE"); - assert_eq!(regs[2], 3, "ADD after FENCE"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ ( - 0u32, + REG_ID.0, PlainMemLayout { - k: k_ram, - d: d_ram, + k: 32, + d: 5, n_side: 2, - lanes: 1, + lanes: 2, }, ), ( - 1u32, + RAM_ID.0, PlainMemLayout { - k: k_prog, - d: d_prog, + k: 16, + d: 4, n_side: 2, lanes: 1, }, ), - ]); - - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, 
&trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } + ]) } -#[test] -fn rv32_b1_ccs_happy_path_rv32i_ecall_markers_program() { - let xlen = 32usize; - let mut program = Vec::new(); - program.extend(load_u32_imm(10, JOLT_CYCLE_TRACK_ECALL_NUM)); - program.push(RiscvInstruction::Halt); - program.push(RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 7, - }); - program.extend(load_u32_imm(10, JOLT_PRINT_ECALL_NUM)); - program.push(RiscvInstruction::Halt); - program.push(RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 1, - imm: 1, - }); - program.push(RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 10, - rs1: 0, - imm: 0, - }); - program.push(RiscvInstruction::Halt); +fn decode_selector_specs(prog_d: usize) -> Vec { + let decode = Rv32DecodeSidecarLayout::new(); + [decode.rd_has_write, decode.ram_has_read, decode.ram_has_write] + .into_iter() + .map(|col| TraceShoutBusSpec { + table_id: rv32_decode_lookup_table_id_for_col(col), + ell_addr: prog_d, + n_vals: 1usize, + }) + .collect() +} + +fn full_rv32i_table_ids() -> Vec { + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + [ + RiscvOpcode::And, + RiscvOpcode::Xor, + RiscvOpcode::Or, + RiscvOpcode::Add, + RiscvOpcode::Sub, + RiscvOpcode::Slt, + RiscvOpcode::Sltu, + RiscvOpcode::Sll, + RiscvOpcode::Srl, + RiscvOpcode::Sra, + RiscvOpcode::Eq, + RiscvOpcode::Neq, + ] + .into_iter() + .map(|op| shout.opcode_to_id(op).0) + .collect() +} +fn exec_table_for(program: Vec, min_len: usize, max_steps: usize) -> Rv32ExecTable { let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 128).expect("trace"); - assert!(trace.did_halt(), 
"expected Halt"); - - let regs = &trace.steps.last().expect("steps").regs_after; - assert_eq!(regs[1], 7, "instruction after ECALL marker executes"); - assert_eq!(regs[2], 8, "instruction after ECALL print executes"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); + let decoded_program = decode_program(&program_bytes).expect("decode_program"); - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, max_steps).expect("trace_program"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, min_len).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + 
exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + exec } #[test] -fn rv32_b1_ccs_happy_path_rv32m_program() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: -6, - }, // x1 = -6 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 3, - }, // x2 = 3 - RiscvInstruction::RAlu { - op: RiscvOpcode::Mul, - rd: 3, - rs1: 1, - rs2: 2, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Divu, - rd: 4, - rs1: 3, - rs2: 2, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Remu, - rd: 5, - rs1: 1, - rs2: 2, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - // Minimal table set for this program: - // - ADD (address/ALU wiring + ADDI), - // - SLTU (DIVU/REMU remainder bound check). 
- let shout_tables = RiscvShoutTables::new(xlen); - let add_id = shout_tables.opcode_to_id(RiscvOpcode::Add).0; - let sltu_id = shout_tables.opcode_to_id(RiscvOpcode::Sltu).0; - let shout_table_ids: [u32; 2] = [add_id, sltu_id]; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - - let table_specs = HashMap::from([ - ( - add_id, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Add, - xlen, - }, - ), - ( - sltu_id, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Sltu, - xlen, +fn rv32_trace_ccs_happy_path_addi_halt() { + let exec = exec_table_for( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, }, - ), - ]); + RiscvInstruction::Halt, + ], + /*min_len=*/ 4, + /*max_steps=*/ 16, + ); - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } + check_ccs_rowwise_zero(&ccs, &x, &w).expect("trace CCS satisfied"); } #[test] -fn rv32_b1_ccs_happy_path_rv32m_signed_program() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 0, - }, // x1 = 0 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 
-3, - }, // x2 = -3 - RiscvInstruction::RAlu { - op: RiscvOpcode::Div, - rd: 3, - rs1: 1, - rs2: 2, - }, // x3 = 0 / -3 - RiscvInstruction::RAlu { - op: RiscvOpcode::Rem, - rd: 4, - rs1: 1, - rs2: 2, - }, // x4 = 0 % -3 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 5, - rs1: 0, - imm: -4, - }, // x5 = -4 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 6, - rs1: 0, - imm: 2, - }, // x6 = 2 - RiscvInstruction::RAlu { - op: RiscvOpcode::Rem, - rd: 7, - rs1: 5, - rs2: 6, - }, // x7 = -4 % 2 - RiscvInstruction::RAlu { - op: RiscvOpcode::Div, - rd: 8, - rs1: 5, - rs2: 6, - }, // x8 = -4 / 2 - RiscvInstruction::RAlu { - op: RiscvOpcode::Mulh, - rd: 9, - rs1: 5, - rs2: 6, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Mulhsu, - rd: 10, - rs1: 5, - rs2: 6, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Mulhu, - rd: 11, - rs1: 5, - rs2: 6, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Div, - rd: 12, - rs1: 5, - rs2: 0, - }, // div by zero - RiscvInstruction::RAlu { - op: RiscvOpcode::Rem, - rd: 13, - rs1: 5, - rs2: 0, - }, // rem by zero - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 128).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let regs = &trace.steps.last().expect("steps").regs_after; - assert_eq!(regs[3], 0, "DIV 0 / -3"); - assert_eq!(regs[4], 0, "REM 0 % -3"); - assert_eq!(regs[7], 0, "REM -4 % 2"); - assert_eq!(regs[8], 0xffff_fffe, "DIV -4 / 2"); - assert_eq!(regs[9], 0xffff_ffff, "MULH -4 * 2"); - assert_eq!(regs[10], 0xffff_ffff, "MULHSU -4 * 2"); - assert_eq!(regs[11], 0x0000_0001, "MULHU 0xffff_fffc * 2"); - assert_eq!(regs[12], 0xffff_ffff, "DIV by zero returns -1"); - assert_eq!(regs[13], 0xffff_fffc, "REM by zero 
returns dividend"); - - let shout_tables = RiscvShoutTables::new(xlen); - let sltu_id = shout_tables.opcode_to_id(RiscvOpcode::Sltu).0; - for &idx in &[2usize, 3, 6, 7] { - let events = &trace.steps[idx].shout_events; - assert_eq!(events.len(), 1, "expected SLTU lookup at step {idx}"); - assert_eq!(events[0].shout_id.0, sltu_id, "expected SLTU table id at step {idx}"); - } - for &idx in &[11usize, 12] { - assert!( - trace.steps[idx].shout_events.is_empty(), - "expected no lookup for div/rem by zero at step {idx}" - ); - } - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, +fn rv32_trace_ccs_happy_path_addi_sw_lw_halt() { + let exec = exec_table_for( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let add_id = shout_tables.opcode_to_id(RiscvOpcode::Add).0; - let shout_table_ids: [u32; 2] = [add_id, sltu_id]; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - - let table_specs = HashMap::from([ - ( - add_id, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Add, - xlen, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0, }, - ), - ( - sltu_id, - LutTableSpec::RiscvOpcode { - opcode: RiscvOpcode::Sltu, - xlen, + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 2, + rs1: 0, + imm: 0, }, - ), - ]); + RiscvInstruction::Halt, + ], + /*min_len=*/ 4, + /*max_steps=*/ 32, + ); - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - 
.with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } + check_ccs_rowwise_zero(&ccs, &x, &w).expect("trace CCS satisfied"); } #[test] -fn rv32_b1_witness_bus_alu_step() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 5, - }, - RiscvInstruction::Halt, - ]; - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - - let step = trace.steps.first().expect("step").clone(); - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(4); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, +fn rv32_trace_ccs_rejects_tampered_pc_transition() { + let exec = exec_table_for( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 1, }, - ), - ]); - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (_ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); 
- let z = rv32_b1_chunk_to_full_witness_checked(&layout, std::slice::from_ref(&step)).expect("witness"); + RiscvInstruction::Halt, + ], + /*min_len=*/ 4, + /*max_steps=*/ 16, + ); - let shout_tables = RiscvShoutTables::new(xlen); - let add_id = shout_tables.opcode_to_id(RiscvOpcode::Add).0; - let add_idx = layout.shout_idx(add_id).expect("add idx"); - let add_lane = &layout.bus.shout_cols[add_idx].lanes[0]; - let prog_lane = &layout.bus.twist_cols[layout.prog_twist_idx].lanes[0]; - let ram_lane = &layout.bus.twist_cols[layout.ram_twist_idx].lanes[0]; + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); - assert_eq!(z[layout.bus.bus_cell(prog_lane.has_read, 0)], F::ONE); - assert_eq!(z[layout.bus.bus_cell(prog_lane.has_write, 0)], F::ZERO); - assert_eq!(z[layout.bus.bus_cell(ram_lane.has_read, 0)], F::ZERO); - assert_eq!(z[layout.bus.bus_cell(ram_lane.has_write, 0)], F::ZERO); - assert_eq!(z[layout.bus.bus_cell(add_lane.has_lookup, 0)], F::ONE); + let idx = layout.cell(layout.trace.pc_before, 1) - layout.m_in; + w[idx] = w[idx] + F::ONE; - let shout_ev = step.shout_events.first().expect("shout event"); - assert_eq!(z[layout.bus.bus_cell(add_lane.val, 0)], F::from_u64(shout_ev.value)); - assert_eq!(z[layout.lookup_key(0)], F::ZERO); - assert_eq!(z[layout.alu_out(0)], F::from_u64(shout_ev.value)); - for (bit_idx, col_id) in add_lane.addr_bits.clone().enumerate() { - let bit = if bit_idx < 64 { (shout_ev.key >> bit_idx) & 1 } else { 0 }; - let expected = if bit == 1 { F::ONE } else { F::ZERO }; - assert_eq!(z[layout.bus.bus_cell(col_id, 0)], expected); - } + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered pc_before on row 1 must violate trace transition wiring" + ); } #[test] -fn rv32_b1_witness_bus_lw_step() { - let xlen = 32usize; - let program 
= vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 42, - }, - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 0, - rs2: 2, - imm: 0, - }, - RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 3, - rs1: 0, - imm: 0, - }, - RiscvInstruction::Halt, - ]; - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - - let step = trace.steps.get(2).expect("lw step").clone(); - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(4); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, +fn rv32_trace_ccs_rejects_first_row_inactive() { + let exec = exec_table_for( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, }, - ), - ]); - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (_ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let z = rv32_b1_chunk_to_full_witness_checked(&layout, std::slice::from_ref(&step)).expect("witness"); - - let shout_tables = RiscvShoutTables::new(xlen); - let add_id = shout_tables.opcode_to_id(RiscvOpcode::Add).0; - let add_idx = layout.shout_idx(add_id).expect("add idx"); - let add_lane = &layout.bus.shout_cols[add_idx].lanes[0]; - let shout_ev = step.shout_events.first().expect("shout event"); + RiscvInstruction::Halt, + ], + /*min_len=*/ 4, + /*max_steps=*/ 16, + ); - let ram_lane = &layout.bus.twist_cols[layout.ram_twist_idx].lanes[0]; - let ram_read = step - .twist_events - .iter() - .find(|ev| ev.twist_id == RAM_ID && ev.kind == TwistOpKind::Read) - .expect("ram 
read"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); - assert_eq!(z[layout.bus.bus_cell(add_lane.has_lookup, 0)], F::ONE); - assert_eq!(z[layout.bus.bus_cell(add_lane.val, 0)], F::from_u64(shout_ev.value)); - // RV32 B1 no longer binds the raw 64-bit Shout key into a single field element. The authoritative - // witness data is in the ADD lane addr_bits. - assert_eq!(z[layout.lookup_key(0)], F::ZERO); - assert_eq!(z[layout.alu_out(0)], F::from_u64(shout_ev.value)); - assert_eq!(z[layout.bus.bus_cell(ram_lane.has_read, 0)], F::ONE); - assert_eq!(z[layout.bus.bus_cell(ram_lane.has_write, 0)], F::ZERO); - assert_eq!(z[layout.bus.bus_cell(ram_lane.rv, 0)], F::from_u64(ram_read.value)); - assert_eq!(z[layout.mem_rv(0)], F::from_u64(ram_read.value)); - assert_eq!(z[layout.eff_addr(0)], F::from_u64(ram_read.addr)); + let idx = layout.cell(layout.trace.active, 0) - layout.m_in; + w[idx] = F::ZERO; - for (bit_idx, col_id) in ram_lane.ra_bits.clone().enumerate() { - let bit = if bit_idx < 64 { - (ram_read.addr >> bit_idx) & 1 - } else { - 0 - }; - let expected = if bit == 1 { F::ONE } else { F::ZERO }; - assert_eq!(z[layout.bus.bus_cell(col_id, 0)], expected); - } + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "trace execution anchor requires active[0] == 1" + ); } #[test] -fn rv32_b1_witness_bus_amoaddw_step() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 5, - }, - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 0, - rs2: 2, - imm: 0, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 7, - }, - RiscvInstruction::Amo { - op: RiscvMemOp::AmoaddW, - rd: 3, - rs1: 0, - rs2: 2, - }, - RiscvInstruction::Halt, - ]; - let program_bytes = 
encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - - let step = trace.steps.get(3).expect("amoaddw step").clone(); - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(4); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (_ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let z = rv32_b1_chunk_to_full_witness_checked(&layout, std::slice::from_ref(&step)).expect("witness"); - - let shout_tables = RiscvShoutTables::new(xlen); - let add_id = shout_tables.opcode_to_id(RiscvOpcode::Add).0; - let add_idx = layout.shout_idx(add_id).expect("add idx"); - let add_lane = &layout.bus.shout_cols[add_idx].lanes[0]; - let shout_ev = step.shout_events.first().expect("shout event"); +fn rv32_trace_shared_bus_config_uses_padding_only_shout_bindings() { + let mut layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let table_ids = full_rv32i_table_ids(); - let ram_lane = &layout.bus.twist_cols[layout.ram_twist_idx].lanes[0]; - let ram_read = step - .twist_events - .iter() - .find(|ev| ev.twist_id == RAM_ID && ev.kind == TwistOpKind::Read) - .expect("ram read"); - let ram_write = step - .twist_events - .iter() - .find(|ev| ev.twist_id == RAM_ID && ev.kind == TwistOpKind::Write) - .expect("ram write"); + let (bus_region_len, reserved_rows) = + rv32_trace_shared_bus_requirements_with_specs(&layout, 
&table_ids, &decode_specs, &mem_layouts) + .expect("trace shared bus requirements"); + layout.m += bus_region_len; - assert_eq!(z[layout.bus.bus_cell(add_lane.has_lookup, 0)], F::ONE); - assert_eq!(z[layout.bus.bus_cell(add_lane.val, 0)], F::from_u64(shout_ev.value)); - assert_eq!(z[layout.lookup_key(0)], F::ZERO); - assert_eq!(z[layout.alu_out(0)], F::from_u64(shout_ev.value)); - for (bit_idx, col_id) in add_lane.addr_bits.clone().enumerate() { - let bit = if bit_idx < 64 { (shout_ev.key >> bit_idx) & 1 } else { 0 }; - let expected = if bit == 1 { F::ONE } else { F::ZERO }; - assert_eq!(z[layout.bus.bus_cell(col_id, 0)], expected); - } - assert_eq!(z[layout.bus.bus_cell(ram_lane.has_read, 0)], F::ONE); - assert_eq!(z[layout.bus.bus_cell(ram_lane.has_write, 0)], F::ONE); - assert_eq!(z[layout.bus.bus_cell(ram_lane.rv, 0)], F::from_u64(ram_read.value)); - assert_eq!(z[layout.bus.bus_cell(ram_lane.wv, 0)], F::from_u64(ram_write.value)); - assert_eq!(z[layout.mem_rv(0)], F::from_u64(ram_read.value)); - assert_eq!(z[layout.ram_wv(0)], F::from_u64(ram_write.value)); + let cfg = rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &table_ids, + &decode_specs, + mem_layouts, + HashMap::<(u32, u64), F>::new(), + ) + .expect("trace shared bus config"); - for (bit_idx, col_id) in ram_lane.ra_bits.clone().enumerate() { - let bit = if bit_idx < 64 { - (ram_read.addr >> bit_idx) & 1 - } else { - 0 - }; - let expected = if bit == 1 { F::ONE } else { F::ZERO }; - assert_eq!(z[layout.bus.bus_cell(col_id, 0)], expected); - } - for (bit_idx, col_id) in ram_lane.wa_bits.clone().enumerate() { - let bit = if bit_idx < 64 { - (ram_write.addr >> bit_idx) & 1 - } else { - 0 - }; - let expected = if bit == 1 { F::ONE } else { F::ZERO }; - assert_eq!(z[layout.bus.bus_cell(col_id, 0)], expected); + assert!(reserved_rows > 0, "expected reserved shared-bus constraints"); + for &table_id in &table_ids { + let lanes = cfg + .shout_cpu + .get(&table_id) + .expect("missing shout_cpu entry for 
table"); + assert!( + lanes.is_empty(), + "trace shared bus uses padding-only shout bindings (table_id={table_id})" + ); } } #[test] -fn rv32_b1_ccs_happy_path_rv32i_byte_half_load_store_program() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 0x100, - }, // x1 = 0x100 - RiscvInstruction::Lui { rd: 2, imm: 0x11223 }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Or, - rd: 2, - rs1: 2, - imm: 0x344, - }, // x2 = 0x11223344 - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 1, - rs2: 2, - imm: 0, - }, // mem[0x100] = 0x11223344 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 6, - rs1: 0, - imm: 0xAA, - }, // x6 = 0xAA - RiscvInstruction::Store { - op: RiscvMemOp::Sb, - rs1: 1, - rs2: 6, - imm: 1, - }, // mem[0x101] = 0xAA - RiscvInstruction::Load { - op: RiscvMemOp::Lb, - rd: 7, - rs1: 1, - imm: 1, - }, // x7 = signext(0xAA) - RiscvInstruction::Load { - op: RiscvMemOp::Lbu, - rd: 8, - rs1: 1, - imm: 1, - }, // x8 = 0xAA - RiscvInstruction::Lui { rd: 9, imm: 0x8 }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 9, - rs1: 9, - imm: 1, - }, // x9 = 0x8001 - RiscvInstruction::Store { - op: RiscvMemOp::Sh, - rs1: 1, - rs2: 9, - imm: 0, - }, // mem[0x100..] 
= 0x8001 - RiscvInstruction::Load { - op: RiscvMemOp::Lh, - rd: 10, - rs1: 1, - imm: 0, - }, // x10 = signext(0x8001) - RiscvInstruction::Load { - op: RiscvMemOp::Lhu, - rd: 11, - rs1: 1, - imm: 0, - }, // x11 = 0x8001 - RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 12, - rs1: 1, - imm: 0, - }, // x12 = 0x11228001 - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 128).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let regs = &trace.steps.last().expect("steps").regs_after; - assert_eq!(regs[7], 0xffff_ffaa, "LB sign-extends 0xAA"); - assert_eq!(regs[8], 0x0000_00aa, "LBU zero-extends 0xAA"); - assert_eq!(regs[10], 0xffff_8001, "LH sign-extends 0x8001"); - assert_eq!(regs[11], 0x0000_8001, "LHU zero-extends 0x8001"); - assert_eq!(regs[12], 0x1122_8001, "LW reads merged word"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); +fn rv32_trace_shared_bus_decode_lookup_binds_to_pc_before() { + let mut layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let table_ids = full_rv32i_table_ids(); - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = 
NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); + let (bus_region_len, _) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &table_ids, &decode_specs, &mem_layouts) + .expect("trace shared bus requirements"); + layout.m += bus_region_len; - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, + let cfg = rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &table_ids, + &decode_specs, + mem_layouts, + HashMap::<(u32, u64), F>::new(), ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_byte_store_updates_aligned_word() { - let xlen = 32usize; - let mut program = Vec::new(); - program.push(RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 0x100, - }); // x1 = 0x100 - program.extend(load_u32_imm(2, 0x1122_3344)); // x2 = 0x11223344 - program.push(RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 1, - rs2: 2, - imm: 0, - }); // mem[0x100] = 0x11223344 - program.push(RiscvInstruction::Load { - op: RiscvMemOp::Lb, - rd: 3, - rs1: 1, - imm: 1, - }); // x3 = mem[0x101] = 0x33 - program.push(RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 4, - rs1: 0, - imm: 0xAA, - }); // x4 = 0xAA - program.push(RiscvInstruction::Store { - op: RiscvMemOp::Sb, - rs1: 1, - rs2: 4, - imm: 1, - }); // mem[0x101] = 0xAA - program.push(RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 5, - rs1: 1, - imm: 0, - }); // x5 = 0x1122AA44 - program.push(RiscvInstruction::Halt); - - let program_bytes = encode_program(&program); - let mut 
cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 128).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); + .expect("trace shared bus config"); - let regs = &trace.steps.last().expect("steps").regs_after; - assert_eq!(regs[3], 0x0000_0033, "LB reads byte from 0x11223344 at +1"); - assert_eq!(regs[5], 0x1122_aa44, "SB updates the aligned LW word"); + let decode = Rv32DecodeSidecarLayout::new(); + let table_id = rv32_decode_lookup_table_id_for_col(decode.rd_has_write); + let lanes = cfg + .shout_cpu + .get(&table_id) + .expect("missing decode shout_cpu entry"); - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn 
rv32_b1_ccs_rejects_misaligned_lh() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::Load { - op: RiscvMemOp::Lh, - rd: 1, - rs1: 0, - imm: 0x101, // misaligned (addr % 2 != 0) - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mcs_wit) = steps.remove(0); - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "misaligned LH should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_misaligned_lw() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::Load { - op: 
RiscvMemOp::Lw, - rd: 1, - rs1: 0, - imm: 0x102, // misaligned (addr % 4 != 0) - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mcs_wit) = steps.remove(0); - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "misaligned LW should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_misaligned_sh() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::Store { - op: RiscvMemOp::Sh, - rs1: 0, - rs2: 0, - imm: 0x101, // misaligned (addr % 2 != 0) - }, - RiscvInstruction::Halt, - ]; - - let 
program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mcs_wit) = steps.remove(0); - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "misaligned SH should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_misaligned_sw() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 0, - rs2: 0, - imm: 0x102, // misaligned (addr % 4 != 0) - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); 
- let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mcs_wit) = steps.remove(0); - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "misaligned SW should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_happy_path_rv32a_amoaddw_program() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 0x100, - }, // x1 = 0x100 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 10, - }, // x2 = 10 - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 1, - rs2: 2, - imm: 0, - }, // mem[0x100] = 10 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 5, - }, // x3 = 
5 - RiscvInstruction::Amo { - op: RiscvMemOp::AmoaddW, - rd: 4, - rs1: 1, - rs2: 3, - }, // x4 = old, mem = old + 5 - RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 5, - rs1: 1, - imm: 0, - }, // x5 = new - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_rejects_tampered_ram_write_value_for_amoaddw() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 
1, - rs1: 0, - imm: 0x100, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 10, - }, - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 1, - rs2: 2, - imm: 0, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 5, - }, - RiscvInstruction::Amo { - op: RiscvMemOp::AmoaddW, - rd: 4, - rs1: 1, - rs2: 3, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let amo_step_idx = 4usize; - let (mcs_inst, mut mcs_wit) = steps.remove(amo_step_idx); - - let ram_wv_w_idx = layout - .ram_wv - 
.checked_sub(layout.m_in) - .expect("ram_wv must be in private witness"); - mcs_wit.w[ram_wv_w_idx] += F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "tampered RAM write value should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_happy_path_rv32a_word_amos_program() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 0x100, - }, // x1 = 0x100 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 0b1111, - }, // x2 = 15 - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 1, - rs2: 2, - imm: 0, - }, // mem[0x100] = 15 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 0b1010, - }, // x3 = 10 - RiscvInstruction::Amo { - op: RiscvMemOp::AmoandW, - rd: 4, - rs1: 1, - rs2: 3, - }, // mem &= 10 - RiscvInstruction::Amo { - op: RiscvMemOp::AmoorW, - rd: 5, - rs1: 1, - rs2: 3, - }, // mem |= 10 - RiscvInstruction::Amo { - op: RiscvMemOp::AmoxorW, - rd: 6, - rs1: 1, - rs2: 3, - }, // mem ^= 10 - RiscvInstruction::Amo { - op: RiscvMemOp::AmoswapW, - rd: 7, - rs1: 1, - rs2: 2, - }, // mem = 15 - RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 8, - rs1: 1, - imm: 0, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let 
initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_chunk_size_2_padding_carries_state() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; 
- let chunk_size = 2usize; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, chunk_size).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - chunk_size, - ) - .expect("shared bus"); - - let chunks = CpuArithmetization::build_ccs_chunks(&cpu, &trace, chunk_size).expect("build chunks"); - for (mcs_inst, mcs_wit) in chunks { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_branches_and_jal() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 1, - }, // x2 = 1 - RiscvInstruction::Branch { - cond: BranchCondition::Eq, - rs1: 1, - rs2: 2, - imm: 8, - }, // taken -> skip next instruction - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 99, - }, // skipped - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 5, - }, // x3 = 5 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, // x2 = 2 - RiscvInstruction::Branch { - cond: BranchCondition::Ne, - rs1: 1, - rs2: 2, - imm: 8, - }, // taken -> skip next instruction - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 4, - rs1: 0, - imm: 77, - }, // skipped - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 4, - rs1: 0, - imm: 7, - }, // x4 = 7 - RiscvInstruction::Jal { rd: 5, imm: 8 }, // x5=pc+4, jump over next instruction - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 6, - rs1: 0, - imm: 
123, - }, // skipped - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_rv32i_alu_ops() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 5, - }, // x1 = 5 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 3, - }, // x2 = 3 - RiscvInstruction::RAlu { - op: RiscvOpcode::Sub, - rd: 3, - rs1: 1, - rs2: 2, - 
}, // x3 = x1 - x2 - RiscvInstruction::RAlu { - op: RiscvOpcode::And, - rd: 4, - rs1: 1, - rs2: 2, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Or, - rd: 5, - rs1: 1, - imm: 0x0f, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Xor, - rd: 6, - rs1: 1, - imm: 0x0f, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Sll, - rd: 7, - rs1: 1, - rs2: 2, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Sll, - rd: 8, - rs1: 1, - imm: 2, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Srl, - rd: 9, - rs1: 1, - rs2: 2, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Sra, - rd: 10, - rs1: 1, - imm: 1, - }, - RiscvInstruction::RAlu { - op: RiscvOpcode::Slt, - rd: 11, - rs1: 2, - rs2: 1, - }, // (3 < 5) => 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Sltu, - rd: 12, - rs1: 2, - imm: 4, - }, // (3 < 4) => 1 - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - 
&table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_branches_blt_bge_bltu_bgeu() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, // x2 = 2 - RiscvInstruction::Branch { - cond: BranchCondition::Lt, - rs1: 1, - rs2: 2, - imm: 8, - }, // taken -> skip next instruction - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 111, - }, // skipped - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 3, - }, // x3 = 3 - RiscvInstruction::Branch { - cond: BranchCondition::Ge, - rs1: 2, - rs2: 1, - imm: 8, - }, // taken -> skip next instruction - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 4, - rs1: 0, - imm: 222, - }, // skipped - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 4, - rs1: 0, - imm: 4, - }, // x4 = 4 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 5, - rs1: 0, - imm: -1, - }, // x5 = 0xffff_ffff - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 6, - rs1: 0, - imm: 1, - }, // x6 = 1 - RiscvInstruction::Branch { - cond: BranchCondition::Ltu, - rs1: 5, - rs2: 6, - imm: 8, - }, // not taken -> execute next instruction - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 7, - rs1: 0, - imm: 7, - }, // x7 = 7 - RiscvInstruction::Branch { - cond: BranchCondition::Geu, - rs1: 5, - rs2: 6, - imm: 8, - }, // taken -> skip next instruction - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 8, - rs1: 0, 
- imm: 888, - }, // skipped - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 8, - rs1: 0, - imm: 8, - }, // x8 = 8 - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 128).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_jalr_masks_lsb() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 13, - }, // x1 = 13 - RiscvInstruction::Jalr { rd: 0, rs1: 1, imm: 0 }, // pc = (13 + 0) & !1 = 
12 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 7, - }, // skipped - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 9, - }, // executed - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - for (mcs_inst, mcs_wit) in steps { - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - } -} - -#[test] -fn rv32_b1_ccs_rejects_step_after_halt_within_chunk() { - let xlen = 32usize; - let program = vec![RiscvInstruction::Halt, RiscvInstruction::Halt]; - let program_bytes = 
encode_program(&program); - - let w0 = u32::from_le_bytes(program_bytes[0..4].try_into().expect("word0")); - let w1 = u32::from_le_bytes(program_bytes[4..8].try_into().expect("word1")); - - let regs = vec![0u64; 32]; - let steps: Vec> = vec![ - StepTrace { - cycle: 0, - pc_before: 0, - pc_after: 4, - opcode: w0, - regs_before: regs.clone(), - regs_after: regs.clone(), - twist_events: vec![TwistEvent { - twist_id: PROG_ID, - kind: TwistOpKind::Read, - addr: 0, - value: w0 as u64, - lane: None, - }], - shout_events: Vec::new(), - halted: true, - }, - StepTrace { - cycle: 1, - pc_before: 4, - pc_after: 8, - opcode: w1, - regs_before: regs.clone(), - regs_after: regs.clone(), - twist_events: vec![TwistEvent { - twist_id: PROG_ID, - kind: TwistOpKind::Read, - addr: 4, - value: w1 as u64, - lane: None, - }], - shout_events: Vec::new(), - halted: true, - }, - ]; - let trace = VmTrace { steps }; - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let chunk_size = 2usize; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, chunk_size).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - chunk_size, - ) - .expect("shared bus"); - - let mut chunks = CpuArithmetization::build_ccs_chunks(&cpu, 
&trace, chunk_size).expect("chunks"); - assert_eq!(chunks.len(), 1, "expected single chunk"); - let (mcs_inst, mcs_wit) = chunks.pop().expect("chunk"); - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "step after HALT should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_tampered_pc_out() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mut mcs_wit) = 
steps.remove(0); - - let pc_out_w_idx = layout - .pc_out - .checked_sub(layout.m_in) - .expect("pc_out must be in private witness"); - mcs_wit.w[pc_out_w_idx] += F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "tampered witness should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_non_boolean_prog_read_addr_bit() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let 
(mcs_inst, mut mcs_wit) = steps.remove(0); - - let prog_cols = &layout.bus.twist_cols[layout.prog_twist_idx].lanes[0]; - let bit_col_id = prog_cols.ra_bits.start + 2; // keep alignment bits [0,1] untouched - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("prog bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - let new_bit = F::from_u64(2); - mcs_wit.w[bit_w_idx] = new_bit; - - let delta = new_bit - old_bit; - let pc_in_w_idx = layout - .pc_in - .checked_sub(layout.m_in) - .expect("pc_in must be in private witness"); - let pc_out_w_idx = layout - .pc_out - .checked_sub(layout.m_in) - .expect("pc_out must be in private witness"); - mcs_wit.w[pc_in_w_idx] += delta * F::from_u64(1 << 2); - mcs_wit.w[pc_out_w_idx] += delta * F::from_u64(1 << 2); - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "non-boolean prog addr bit should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_non_boolean_shout_addr_bit() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, // x2 = 2 - RiscvInstruction::RAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 1, - rs2: 2, - }, // x3 = x1 + x2 (Shout ADD active) - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - 
lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let add_step_idx = 2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(add_step_idx); - - let instr = decode_instruction(trace.steps[add_step_idx].opcode).expect("decode"); - let (rs1_idx, rs2_idx) = match instr { - RiscvInstruction::RAlu { - op: RiscvOpcode::Add, - rs1, - rs2, - .. - } => (rs1 as usize, rs2 as usize), - other => panic!("expected ADD at step {add_step_idx}, got {other:?}"), - }; - - // Flip one ADD shout key bit to a non-boolean value, and adjust CPU columns so that - // all *linear* bindings still hold. Bitness constraints should still reject. 
- let add_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Add).0; - let add_shout_idx = layout.shout_idx(add_id).expect("ADD shout idx"); - let shout_cols = &layout.bus.shout_cols[add_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // key bit 0 (rs1 bit 0) - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - let new_bit = F::from_u64(2); - mcs_wit.w[bit_w_idx] = new_bit; - let delta = new_bit - old_bit; - - // Update packed key (lookup_key) to match the mutated bits. - let lookup_key_w_idx = layout - .lookup_key(0) - .checked_sub(layout.m_in) - .expect("lookup_key must be in private witness"); - mcs_wit.w[lookup_key_w_idx] += delta; - - // Update rs1_val and corresponding register snapshots to match the mutated even-bit packing. - let rs1_val_w_idx = layout - .rs1_val(0) - .checked_sub(layout.m_in) - .expect("rs1_val must be in private witness"); - mcs_wit.w[rs1_val_w_idx] += delta; - - let reg1_in_w_idx = layout - .reg_in(rs1_idx, 0) - .checked_sub(layout.m_in) - .expect("reg_in must be in private witness"); - let reg1_out_w_idx = layout - .reg_out(rs1_idx, 0) - .checked_sub(layout.m_in) - .expect("reg_out must be in private witness"); - mcs_wit.w[reg1_in_w_idx] += delta; - mcs_wit.w[reg1_out_w_idx] += delta; - - // Keep rs2 snapshots consistent as well (defensive sanity); no bit change expected. 
- let reg2_in_w_idx = layout - .reg_in(rs2_idx, 0) - .checked_sub(layout.m_in) - .expect("reg_in must be in private witness"); - let reg2_out_w_idx = layout - .reg_out(rs2_idx, 0) - .checked_sub(layout.m_in) - .expect("reg_out must be in private witness"); - mcs_wit.w[reg2_out_w_idx] = mcs_wit.w[reg2_in_w_idx]; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "non-boolean shout addr bit should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_rom_value_mismatch() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, 
initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mut mcs_wit) = steps.remove(0); - - let prog_cols = &layout.bus.twist_cols[layout.prog_twist_idx].lanes[0]; - let rv_z = layout.bus.bus_cell(prog_cols.rv, 0); - let rv_w_idx = rv_z.checked_sub(layout.m_in).expect("rv in witness"); - mcs_wit.w[rv_w_idx] += F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "rom value mismatch should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_tampered_regfile() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - 
.with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mut mcs_wit) = steps.remove(0); - - // Tamper with a non-rd register output (x2) without updating reg_in. - let r = 2usize; - let reg_out_w_idx = layout - .reg_out(r, 0) - .checked_sub(layout.m_in) - .expect("reg_out in witness"); - mcs_wit.w[reg_out_w_idx] += F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "tampered regfile should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_tampered_x0() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - 
layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mut mcs_wit) = steps.remove(0); - - let x0_out_w_idx = layout - .reg_out(0, 0) - .checked_sub(layout.m_in) - .expect("x0 out in witness"); - mcs_wit.w[x0_out_w_idx] = F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "tampered x0 should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_binds_public_initial_and_final_state() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 5, - }, // x1 = 5 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 7, - }, // x2 = 7 - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let chunk_size = 8usize; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, chunk_size).expect("ccs"); - let params = 
NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - chunk_size, - ) - .expect("shared bus"); - - let mut chunks = CpuArithmetization::build_ccs_chunks(&cpu, &trace, chunk_size).expect("build chunks"); - assert_eq!(chunks.len(), 1, "chunk_size>N should create one chunk"); - let (mcs_inst, mcs_wit) = chunks.remove(0); - - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); - - let first = trace.steps.first().expect("trace non-empty"); - assert_eq!(mcs_inst.x[layout.pc0], F::from_u64(first.pc_before)); - for r in 0..32 { - assert_eq!(mcs_inst.x[layout.regs0_start + r], F::from_u64(first.regs_before[r])); - } - assert_eq!(mcs_inst.x[layout.regs0_start], F::ZERO); - - let last = trace.steps.last().expect("trace non-empty"); - assert_eq!(mcs_inst.x[layout.pc_final], F::from_u64(last.pc_after)); - for r in 0..32 { - assert_eq!(mcs_inst.x[layout.regs_final_start + r], F::from_u64(last.regs_after[r])); - } - assert_eq!(mcs_inst.x[layout.regs_final_start], F::ZERO); - - let mut x_bad = mcs_inst.x.clone(); - x_bad[layout.pc0] += F::ONE; - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &x_bad, &mcs_wit.w).is_err(), - "tampered pc0 should not satisfy CCS" - ); - - let mut x_bad = mcs_inst.x.clone(); - x_bad[layout.pc_final] += F::ONE; - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &x_bad, &mcs_wit.w).is_err(), - "tampered pc_final should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_rom_addr_mismatch() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = 
encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mut mcs_wit) = steps.remove(0); - - let prog_cols = &layout.bus.twist_cols[layout.prog_twist_idx].lanes[0]; - let bit_col_id = prog_cols.ra_bits.start + 2; // keep alignment bits [0,1] untouched - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("prog bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "rom address mismatch should not satisfy CCS" - ); -} - 
-#[test] -fn rv32_b1_ccs_rejects_decode_bit_mismatch() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 8).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let (mcs_inst, mut mcs_wit) = steps.remove(0); - - let bit_z = layout.instr_bit(0, 0); - let bit_w_idx = bit_z - .checked_sub(layout.m_in) - .expect("instr_bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, 
&mcs_wit.w).is_err(), - "decode bit mismatch should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, // x2 = 2 - RiscvInstruction::RAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 1, - rs2: 2, - }, // x3 = x1 + x2 (Shout ADD active) - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let add_step_idx = 
2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(add_step_idx); - - let add_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Add).0; - let add_shout_idx = layout.shout_idx(add_id).expect("ADD shout idx"); - let shout_cols = &layout.bus.shout_cols[add_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // rs1 bit 0 - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch_lw_eff_addr() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 1, - rs1: 0, - imm: 0, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(4); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - 
params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let lw_step_idx = 0usize; - let (mcs_inst, mut mcs_wit) = steps.remove(lw_step_idx); - - let add_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Add).0; - let add_shout_idx = layout.shout_idx(add_id).expect("ADD shout idx"); - let shout_cols = &layout.bus.shout_cols[add_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch (LW effective address) should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch_amoaddw() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 5, - }, - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 0, - rs2: 2, - imm: 0, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 7, - }, - RiscvInstruction::Amo { - op: RiscvMemOp::AmoaddW, - rd: 3, - rs1: 0, - rs2: 2, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let 
(k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(4); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let amoadd_step_idx = 3usize; - let (mcs_inst, mut mcs_wit) = steps.remove(amoadd_step_idx); - - let add_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Add).0; - let add_shout_idx = layout.shout_idx(add_id).expect("ADD shout idx"); - let shout_cols = &layout.bus.shout_cols[add_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 (mem_rv bit 0) - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch (AMOADD.W operands) should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch_beq() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - 
rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, // x2 = 2 - RiscvInstruction::Branch { - cond: BranchCondition::Eq, - rs1: 1, - rs2: 2, - imm: 8, - }, // not taken -> execute HALT - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(4); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let beq_step_idx = 2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(beq_step_idx); - - let eq_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Eq).0; - let eq_shout_idx = layout.shout_idx(eq_id).expect("EQ shout idx"); - let shout_cols = 
&layout.bus.shout_cols[eq_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch (BEQ operands) should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch_bne() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 5, - }, // x1 = 5 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 5, - }, // x2 = 5 - RiscvInstruction::Branch { - cond: BranchCondition::Ne, - rs1: 1, - rs2: 2, - imm: 8, - }, // not taken -> execute next - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 0, - imm: 7, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 32).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = 
NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let bne_step_idx = 2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(bne_step_idx); - - let neq_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Neq).0; - let neq_shout_idx = layout.shout_idx(neq_id).expect("NEQ shout idx"); - let shout_cols = &layout.bus.shout_cols[neq_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch (BNE operands) should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch_ori() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 0x123, - }, // x1 = 0x123 - RiscvInstruction::IAlu { - op: RiscvOpcode::Or, - rd: 2, - rs1: 1, - imm: 0x55, - }, // x2 = x1 | 0x55 (Shout OR active) - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected 
Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let ori_step_idx = 1usize; - let (mcs_inst, mut mcs_wit) = steps.remove(ori_step_idx); - - let or_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Or).0; - let or_shout_idx = layout.shout_idx(or_id).expect("OR shout idx"); - let shout_cols = &layout.bus.shout_cols[or_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch (ORI imm) should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch_slli() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - 
}, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Sll, - rd: 2, - rs1: 1, - imm: 3, - }, // x2 = x1 << 3 (Shout SLL active) - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let slli_step_idx = 1usize; - let (mcs_inst, mut mcs_wit) = steps.remove(slli_step_idx); - - let sll_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Sll).0; - let sll_shout_idx = layout.shout_idx(sll_id).expect("SLL shout idx"); - let shout_cols = &layout.bus.shout_cols[sll_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 - 
let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch (SLLI imm) should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_sltu_key_bit_mismatch_divu_remainder_check() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 10, - }, // x1 = 10 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 3, - }, // x2 = 3 - RiscvInstruction::RAlu { - op: RiscvOpcode::Divu, - rd: 3, - rs1: 1, - rs2: 2, - }, // x3 = x1 / x2 (sltu(rem, divisor) lookup when divisor != 0) - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - 
&table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let divu_step_idx = 2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(divu_step_idx); - - let sltu_id = RiscvShoutTables::new(xlen) - .opcode_to_id(RiscvOpcode::Sltu) - .0; - let sltu_shout_idx = layout.shout_idx(sltu_id).expect("SLTU shout idx"); - let shout_cols = &layout.bus.shout_cols[sltu_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 (remainder bit 0) - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "sltu(rem, divisor) shout key mismatch should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_shout_key_bit_mismatch_auipc_pc_operand() { - let xlen = 32usize; - let program = vec![RiscvInstruction::Auipc { rd: 1, imm: 0 }, RiscvInstruction::Halt]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = 
prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let auipc_step_idx = 0usize; - let (mcs_inst, mut mcs_wit) = steps.remove(auipc_step_idx); - - let add_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Add).0; - let add_shout_idx = layout.shout_idx(add_id).expect("ADD shout idx"); - let shout_cols = &layout.bus.shout_cols[add_shout_idx].lanes[0]; - let bit_col_id = shout_cols.addr_bits.start + 0; // operand0 bit 0 (pc bit 0) - let bit_z = layout.bus.bus_cell(bit_col_id, 0); - let bit_w_idx = bit_z.checked_sub(layout.m_in).expect("bit in witness"); - let old_bit = mcs_wit.w[bit_w_idx]; - mcs_wit.w[bit_w_idx] = F::ONE - old_bit; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "shout key mismatch (AUIPC pc operand) should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_cheating_mul_hi_all_ones() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 1, - }, // x2 = 1 - RiscvInstruction::RAlu { - op: RiscvOpcode::Mul, - rd: 3, - rs1: 1, - rs2: 2, - }, // x3 = 1 * 1 - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); 
- cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let mul_step_idx = 2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(mul_step_idx); - - let mul_hi = u32::MAX as u64; - let mul_lo = 2u64; - - let mul_hi_z = layout.mul_hi(0); - let mul_hi_w = mul_hi_z - .checked_sub(layout.m_in) - .expect("mul_hi in witness"); - mcs_wit.w[mul_hi_w] = F::from_u64(mul_hi); - - let mul_lo_z = layout.mul_lo(0); - let mul_lo_w = mul_lo_z - .checked_sub(layout.m_in) - .expect("mul_lo in witness"); - mcs_wit.w[mul_lo_w] = F::from_u64(mul_lo); - - let rd_write_z = layout.rd_write_val(0); - let rd_write_w = rd_write_z - .checked_sub(layout.m_in) - .expect("rd_write_val in witness"); - 
mcs_wit.w[rd_write_w] = F::from_u64(mul_lo); - - let reg_out_z = layout.reg_out(3, 0); - let reg_out_w = reg_out_z - .checked_sub(layout.m_in) - .expect("reg_out in witness"); - mcs_wit.w[reg_out_w] = F::from_u64(mul_lo); - - // Make the u32 bit decompositions consistent with the cheated values. - for bit in 0..32 { - let hi_bit_z = layout.mul_hi_bit(bit, 0); - let hi_bit_w = hi_bit_z - .checked_sub(layout.m_in) - .expect("mul_hi_bit in witness"); - mcs_wit.w[hi_bit_w] = F::ONE; - - let lo_bit = (mul_lo >> bit) & 1; - let lo_bit_z = layout.mul_lo_bit(bit, 0); - let lo_bit_w = lo_bit_z - .checked_sub(layout.m_in) - .expect("mul_lo_bit in witness"); - mcs_wit.w[lo_bit_w] = if lo_bit == 1 { F::ONE } else { F::ZERO }; - - let rd_bit_z = layout.rd_write_bit(bit, 0); - let rd_bit_w = rd_bit_z - .checked_sub(layout.m_in) - .expect("rd_write_bit in witness"); - mcs_wit.w[rd_bit_w] = if lo_bit == 1 { F::ONE } else { F::ZERO }; - } - for k in 0..31 { - let prefix_z = layout.mul_hi_prefix(k, 0); - let prefix_w = prefix_z - .checked_sub(layout.m_in) - .expect("mul_hi_prefix in witness"); - mcs_wit.w[prefix_w] = F::ONE; - } - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "cheating MUL decomposition should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_wrong_shout_table_activation() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, // x1 = 1 - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, // x2 = 2 - RiscvInstruction::RAlu { - op: RiscvOpcode::Add, - rd: 3, - rs1: 1, - rs2: 2, - }, // x3 = x1 + x2 (ADD table active) - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace 
= trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let add_step_idx = 2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(add_step_idx); - - let eq_id = RiscvShoutTables::new(xlen).opcode_to_id(RiscvOpcode::Eq).0; - let eq_shout_idx = layout.shout_idx(eq_id).expect("EQ shout idx"); - let eq_cols = &layout.bus.shout_cols[eq_shout_idx].lanes[0]; - let has_lookup_z = layout.bus.bus_cell(eq_cols.has_lookup, 0); - let has_lookup_w_idx = has_lookup_z - .checked_sub(layout.m_in) - .expect("has_lookup in witness"); - mcs_wit.w[has_lookup_w_idx] = F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "wrong shout table activation should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_ram_read_value_mismatch() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 
1, - rs1: 0, - imm: 9, - }, // x1 = 9 - RiscvInstruction::Store { - op: RiscvMemOp::Sw, - rs1: 0, - rs2: 1, - imm: 0x100, - }, // mem[0x100] = x1 - RiscvInstruction::Load { - op: RiscvMemOp::Lw, - rd: 2, - rs1: 0, - imm: 0x100, - }, // x2 = mem[0x100] - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 64).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x200); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - 1, - ) - .expect("shared bus"); - - let mut steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); - let lw_step_idx = 2usize; - let (mcs_inst, mut mcs_wit) = steps.remove(lw_step_idx); - - let ram_cols = &layout.bus.twist_cols[layout.ram_twist_idx].lanes[0]; - let rv_z = layout.bus.bus_cell(ram_cols.rv, 0); - let rv_w_idx = 
rv_z.checked_sub(layout.m_in).expect("rv in witness"); - mcs_wit.w[rv_w_idx] += F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "ram read value mismatch should not satisfy CCS" - ); -} - -#[test] -fn rv32_b1_ccs_rejects_chunk_size_2_continuity_break() { - let xlen = 32usize; - let program = vec![ - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 1, - rs1: 0, - imm: 1, - }, - RiscvInstruction::IAlu { - op: RiscvOpcode::Add, - rd: 2, - rs1: 0, - imm: 2, - }, - RiscvInstruction::Halt, - ]; - - let program_bytes = encode_program(&program); - let mut cpu_vm = RiscvCpu::new(xlen); - cpu_vm.load_program(0, program.clone()); - let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); - let shout = RiscvShoutTables::new(xlen); - let trace = trace_program(cpu_vm, memory, shout, 16).expect("trace"); - assert!(trace.did_halt(), "expected Halt"); - - let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); - let (k_ram, d_ram) = pow2_ceil_k(0x80); - let mem_layouts = HashMap::from([ - ( - 0u32, - PlainMemLayout { - k: k_ram, - d: d_ram, - n_side: 2, - lanes: 1, - }, - ), - ( - 1u32, - PlainMemLayout { - k: k_prog, - d: d_prog, - n_side: 2, - lanes: 1, - }, - ), - ]); - let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); - - let shout_table_ids = RV32I_SHOUT_TABLE_IDS; - let chunk_size = 2usize; - let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, chunk_size).expect("ccs"); - let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); - let table_specs = rv32i_table_specs(xlen); - - let cpu = R1csCpu::new( - ccs, - params, - NoopCommit::default(), - layout.m_in, - &HashMap::new(), - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ) - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), - chunk_size, - ) - .expect("shared bus"); - - let mut chunks = 
CpuArithmetization::build_ccs_chunks(&cpu, &trace, chunk_size).expect("build chunks"); - let (mcs_inst, mut mcs_wit) = chunks.remove(0); - - let pc_in_w_idx = layout - .pc_in(1) - .checked_sub(layout.m_in) - .expect("pc_in lane 1 in witness"); - mcs_wit.w[pc_in_w_idx] += F::ONE; - - assert!( - check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).is_err(), - "continuity break should not satisfy CCS" - ); + assert_eq!(lanes.len(), 1, "decode lookup should bind exactly one lane"); + let lane = &lanes[0]; + assert_eq!(lane.has_lookup, CPU_BUS_COL_DISABLED); + assert_eq!(lane.val, CPU_BUS_COL_DISABLED); + assert_eq!(lane.addr, Some(layout.cell(layout.trace.pc_before, 0))); } diff --git a/crates/neo-memory/tests/riscv_exec_table.rs b/crates/neo-memory/tests/riscv_exec_table.rs new file mode 100644 index 00000000..64d63d7e --- /dev/null +++ b/crates/neo-memory/tests/riscv_exec_table.rs @@ -0,0 +1,238 @@ +use neo_memory::riscv::exec_table::{Rv32ExecTable, Rv32ShoutEventTable}; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, interleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_vm_trace::trace_program; + +#[test] +fn rv32_exec_table_matches_trace_lane_conventions_addi_halt() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + // Initialize only PROG; RAM starts empty/zeroed and REG starts with all-zero regs. 
+ let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + assert_eq!(trace.steps.len(), 2, "expected ADDI + HALT trace"); + + let table = Rv32ExecTable::from_trace(&trace).expect("Rv32ExecTable::from_trace"); + table.validate_pc_chain().expect("pc chain"); + assert_eq!(table.rows.len(), 2); + + // Step 0: ADDI x1,x0,1 + { + let row0 = &table.rows[0]; + assert!(row0.active); + assert_eq!(row0.pc_before, 0); + assert_eq!(row0.pc_after, 4); + assert_eq!(row0.fields.opcode, 0x13); + assert_eq!(row0.fields.rs1, 0); + assert_eq!(row0.fields.rd, 1); + + // PROG fetch matches the instruction word for this row. + let prog_read = row0.prog_read.as_ref().expect("expected PROG read"); + assert_eq!(prog_read.addr, row0.pc_before); + assert_eq!(prog_read.value, row0.instr_word as u64); + + // REG lane policy: lane0 reads rs1_field, lane1 reads rs2_field. + let rs1 = row0.reg_read_lane0.as_ref().expect("expected rs1 read"); + let rs2 = row0.reg_read_lane1.as_ref().expect("expected rs2 read"); + assert_eq!(rs1.addr, 0); + assert_eq!(rs1.value, 0); + assert_eq!(rs2.addr, row0.fields.rs2 as u64); + + // Writeback: rd_field=1 should be written with value 1. + let w = row0.reg_write_lane0.as_ref().expect("expected rd write"); + assert_eq!(w.addr, 1); + assert_eq!(w.value, 1); + + // ADDI uses one ADD shout lookup: key = interleave(rs1_val, imm_u32), value = rd. + let add_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add); + assert_eq!(row0.shout_events.len(), 1); + let ev = &row0.shout_events[0]; + assert_eq!(ev.shout_id, add_id); + + let imm_u32 = 1u64; + let expected_key = interleave_bits(rs1.value, imm_u32) as u64; + assert_eq!(ev.key, expected_key); + assert_eq!(ev.value, 1); + } + + // Step 1: HALT (ECALL). Lane1 still reads rs2_field (which is 0 for ECALL). 
+ { + let row1 = &table.rows[1]; + assert!(row1.active); + assert_eq!(row1.pc_before, 4); + assert_eq!(row1.fields.opcode, 0x73); + assert!(row1.halted); + + let prog_read = row1.prog_read.as_ref().expect("expected PROG read"); + assert_eq!(prog_read.addr, row1.pc_before); + assert_eq!(prog_read.value, row1.instr_word as u64); + + let rs1 = row1.reg_read_lane0.as_ref().expect("expected rs1 read"); + let rs2 = row1.reg_read_lane1.as_ref().expect("expected rs2 read"); + assert_eq!(rs1.addr, row1.fields.rs1 as u64); + assert_eq!(rs2.addr, row1.fields.rs2 as u64); + assert!(row1.reg_write_lane0.is_none()); + assert!(row1.shout_events.is_empty()); + } +} + +#[test] +fn rv32_exec_table_padding_builds_inactive_rows() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + assert_eq!(trace.steps.len(), 2, "expected ADDI + HALT trace"); + + let table = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + assert_eq!(table.rows.len(), 4); + table.validate_pc_chain().expect("pc chain"); + + // First two rows are active; tail rows are inactive padding with no side effects. 
+ assert!(table.rows[0].active); + assert!(table.rows[1].active); + assert!(!table.rows[2].active); + assert!(!table.rows[3].active); + + let halted_pc = table.rows[1].pc_after; + for r in table.rows.iter().skip(2) { + assert_eq!(r.pc_before, halted_pc); + assert_eq!(r.pc_after, halted_pc); + assert!(r.halted, "padded rows should stay halted"); + assert!(r.prog_read.is_none()); + assert!(r.reg_read_lane0.is_none()); + assert!(r.reg_read_lane1.is_none()); + assert!(r.reg_write_lane0.is_none()); + assert!(r.ram_events.is_empty()); + assert!(r.shout_events.is_empty()); + } + + let cols = table.to_columns(); + assert_eq!(cols.len(), 4); + assert_eq!(cols.active, vec![true, true, false, false]); + assert_eq!(cols.pc_before[2], halted_pc); + assert_eq!(cols.pc_after[3], halted_pc); + assert_eq!(cols.prog_value[2], 0); + assert!(!cols.rd_has_write[3]); +} + +#[test] +fn rv32_shout_event_table_includes_rv32m_rows() { + // Target production behavior: RV32M Shout-backed ops should appear in the + // trace-derived event table used by event-table packed proving paths. + // + // This test is expected to fail until RV32M event-table coverage is fully wired. 
+ let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 9, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + let events = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + + assert!( + events + .rows + .iter() + .any(|row| row.opcode == Some(RiscvOpcode::Mul)), + "expected RV32M (MUL) rows in trace shout event table" + ); +} + +#[test] +fn rv32_exec_table_rejects_jalr_non_strict_target_tamper() { + // Program: + // ADDI x1, x0, 8 + // JALR x2, x1, 0 + // HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 8, + }, + RiscvInstruction::Jalr { rd: 2, rs1: 1, imm: 0 }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + let mut table = 
Rv32ExecTable::from_trace(&trace).expect("Rv32ExecTable::from_trace"); + + // Tamper the JALR row target so it no longer equals rs1+imm under strict policy. + let jalr_row = table + .rows + .iter_mut() + .find(|r| matches!(r.decoded, Some(RiscvInstruction::Jalr { .. }))) + .expect("expected one JALR row"); + jalr_row.pc_after = jalr_row.pc_after.wrapping_add(4); + + assert!( + table.validate_jalr_strict_alignment_policy().is_err(), + "tampered JALR target should fail strict alignment policy validation" + ); +} diff --git a/crates/neo-memory/tests/riscv_rv32m_event_table.rs b/crates/neo-memory/tests/riscv_rv32m_event_table.rs new file mode 100644 index 00000000..80cbc532 --- /dev/null +++ b/crates/neo-memory/tests/riscv_rv32m_event_table.rs @@ -0,0 +1,162 @@ +use neo_memory::riscv::exec_table::{Rv32ExecTable, Rv32MEventTable}; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_vm_trace::trace_program; + +#[test] +fn rv32m_event_table_extracts_and_matches_cpu_semantics() { + // Program: + // ADDI x1,x0,3 + // ADDI x2,x0,5 + // MUL x3,x1,x2 -> 15 + // DIVU x4,x2,x1 -> 1 + // REMU x5,x2,x1 -> 2 + // HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 3, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 5, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 4, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + rd: 5, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 
32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 64).expect("trace_program"); + assert!(trace.did_halt(), "expected program to halt"); + + let exec = Rv32ExecTable::from_trace(&trace).expect("Rv32ExecTable::from_trace"); + let events = Rv32MEventTable::from_exec_table(&exec).expect("Rv32MEventTable::from_exec_table"); + assert_eq!(events.rows.len(), 3, "expected MUL/DIVU/REMU events"); + + // MUL x3,x1,x2: 3*5 = 15 + { + let e = &events.rows[0]; + assert_eq!(e.opcode, RiscvOpcode::Mul); + assert_eq!(e.rs1, 1); + assert_eq!(e.rs2, 2); + assert_eq!(e.rd, 3); + assert_eq!(e.rs1_val, 3); + assert_eq!(e.rs2_val, 5); + assert_eq!(e.expected_rd_val, 15); + assert_eq!(e.rd_write_val, Some(15)); + } + + // DIVU x4,x2,x1: 5/3 = 1 + { + let e = &events.rows[1]; + assert_eq!(e.opcode, RiscvOpcode::Divu); + assert_eq!(e.rs1_val, 5); + assert_eq!(e.rs2_val, 3); + assert_eq!(e.expected_rd_val, 1); + assert_eq!(e.rd_write_val, Some(1)); + } + + // REMU x5,x2,x1: 5%3 = 2 + { + let e = &events.rows[2]; + assert_eq!(e.opcode, RiscvOpcode::Remu); + assert_eq!(e.rs1_val, 5); + assert_eq!(e.rs2_val, 3); + assert_eq!(e.expected_rd_val, 2); + assert_eq!(e.rd_write_val, Some(2)); + } +} + +#[test] +fn rv32m_event_table_is_empty_when_program_has_no_rv32m_ops() { + // Program: ADDI; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 
64).expect("trace_program"); + let exec = Rv32ExecTable::from_trace(&trace).expect("Rv32ExecTable::from_trace"); + let events = Rv32MEventTable::from_exec_table(&exec).expect("Rv32MEventTable::from_exec_table"); + + assert!(events.rows.is_empty(), "expected no RV32M events"); +} + +#[test] +fn rv32m_event_table_single_row_for_single_mul() { + // Program: MUL x1,x0,x0; HALT + let program = vec![ + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 1, + rs1: 0, + rs2: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 64).expect("trace_program"); + let exec = Rv32ExecTable::from_trace(&trace).expect("Rv32ExecTable::from_trace"); + let events = Rv32MEventTable::from_exec_table(&exec).expect("Rv32MEventTable::from_exec_table"); + + assert_eq!(events.rows.len(), 1, "expected exactly one RV32M event"); + let e = &events.rows[0]; + assert_eq!(e.opcode, RiscvOpcode::Mul); + assert_eq!(e.rs1, 0); + assert_eq!(e.rs2, 0); + assert_eq!(e.rd, 1); + assert_eq!(e.rs1_val, 0); + assert_eq!(e.rs2_val, 0); + assert_eq!(e.expected_rd_val, 0); + assert_eq!(e.rd_write_val, Some(0)); +} diff --git a/crates/neo-memory/tests/riscv_rv32m_masked_columns.rs b/crates/neo-memory/tests/riscv_rv32m_masked_columns.rs new file mode 100644 index 00000000..8ab0a8b1 --- /dev/null +++ b/crates/neo-memory/tests/riscv_rv32m_masked_columns.rs @@ -0,0 +1,107 @@ +use neo_memory::riscv::exec_table::{Rv32ExecTable, Rv32ShoutEventTable}; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, interleave_bits, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, 
+ RiscvShoutTables, PROG_ID, +}; +use neo_vm_trace::trace_program; + +fn rv32m_exec_table() -> Rv32ExecTable { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 3, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 5, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulh, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 4, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + rd: 5, + rs1: 2, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + exec +} + +#[test] +fn rv32_trace_shout_event_table_includes_rv32m_rows() { + let exec = rv32m_exec_table(); + let events = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + + assert!( + events + .rows + .iter() + .any(|row| row.opcode == Some(RiscvOpcode::Mulh)), + "expected MULH shout event row" + ); + assert!( + exec.rows.iter().any(|row| matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + .. 
+ }) + )), + "expected DIVU step in execution table" + ); + assert!( + exec.rows.iter().any(|row| matches!( + row.decoded, + Some(RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + .. + }) + )), + "expected REMU step in execution table" + ); +} + +#[test] +fn rv32_trace_shout_event_table_mulh_key_matches_operands() { + let exec = rv32m_exec_table(); + let events = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + + let mulh_row = events + .rows + .iter() + .find(|row| row.opcode == Some(RiscvOpcode::Mulh)) + .expect("expected MULH row"); + + let expected_key = interleave_bits(/*lhs=*/ 3, /*rhs=*/ 5) as u64; + assert_eq!(mulh_row.key, expected_key, "MULH key must encode rs1/rs2 values"); +} diff --git a/crates/neo-memory/tests/riscv_shout_event_table.rs b/crates/neo-memory/tests/riscv_shout_event_table.rs new file mode 100644 index 00000000..4c0edc03 --- /dev/null +++ b/crates/neo-memory/tests/riscv_shout_event_table.rs @@ -0,0 +1,168 @@ +use std::collections::HashMap; + +use neo_memory::riscv::exec_table::{Rv32ExecTable, Rv32ShoutEventTable}; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::extract_shout_lanes_over_time; +use neo_vm_trace::trace_program; + +#[test] +fn rv32_shout_event_table_matches_fixed_lane_extract() { + // Program: + // - ADDI x1,x0,0x1234 + // - ADDI x2,x0,37 + // - SLL x3,x1,x2 (shamt uses low 5 bits => 5) + // - OR x4,x1,x0 + // - HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 0x1234, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 37, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Sll, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Or, + rd: 4, + rs1: 1, + rs2: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = 
encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 64).expect("trace_program"); + assert!(trace.did_halt(), "expected program to halt"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let shout_table_ids = vec![ + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sll).0, + RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Or).0, + ]; + let lanes = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("extract shout lanes"); + assert_eq!(lanes.len(), shout_table_ids.len()); + + let table = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + + // Index events by (row_idx, shout_id); fixed-lane policy should make this unique. + let mut by_row: HashMap<(usize, u32), (u64, u64)> = HashMap::new(); + for e in table.rows.iter() { + assert!( + by_row + .insert((e.row_idx, e.shout_id), (e.key, e.value)) + .is_none(), + "duplicate shout event at row_idx={} shout_id={}", + e.row_idx, + e.shout_id + ); + } + + // For each provisioned shout lane, ensure the per-row key/value matches the event table. 
+ let t = exec.rows.len(); + let mut expected_event_count = 0usize; + for (lane_idx, &shout_id) in shout_table_ids.iter().enumerate() { + let lane = &lanes[lane_idx]; + for row_idx in 0..t { + if lane.has_lookup[row_idx] { + expected_event_count += 1; + let (key, value) = by_row + .get(&(row_idx, shout_id)) + .copied() + .unwrap_or_else(|| panic!("missing shout event row_idx={row_idx} shout_id={shout_id}")); + assert_eq!( + key, lane.key[row_idx], + "key mismatch at row_idx={row_idx} shout_id={shout_id}" + ); + assert_eq!( + value, lane.value[row_idx], + "value mismatch at row_idx={row_idx} shout_id={shout_id}" + ); + } + } + } + assert_eq!(table.rows.len(), expected_event_count, "unexpected shout event count"); + + // Shift canonicalization: the SLL event rhs should be masked to 5 bits. + let sll_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Sll).0; + let sll_ev = table + .rows + .iter() + .find(|e| e.shout_id == sll_id) + .expect("expected SLL shout event"); + assert!(sll_ev.rhs <= 31, "expected canonicalized SLL rhs <= 31"); +} + +#[test] +fn rv32_shout_event_table_is_empty_for_halt_only_program() { + let program = vec![RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 8).expect("trace_program"); + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + + let table = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + assert!(table.rows.is_empty(), "expected no shout events for HALT-only program"); +} + +#[test] +fn 
rv32_shout_event_table_single_row_for_single_xor() { + // Program: XOR x1,x0,x0; HALT + let program = vec![ + RiscvInstruction::RAlu { + op: RiscvOpcode::Xor, + rd: 1, + rs1: 0, + rs2: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + + let table = Rv32ShoutEventTable::from_exec_table(&exec).expect("Rv32ShoutEventTable::from_exec_table"); + assert_eq!(table.rows.len(), 1, "expected exactly one shout event"); + + let ev = &table.rows[0]; + let xor_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Xor).0; + assert_eq!(ev.shout_id, xor_id); + assert_eq!(ev.lhs, 0); + assert_eq!(ev.rhs, 0); + assert_eq!(ev.value, 0); +} diff --git a/crates/neo-memory/tests/riscv_shout_event_table_sparse_matrix.rs b/crates/neo-memory/tests/riscv_shout_event_table_sparse_matrix.rs new file mode 100644 index 00000000..6f0a3663 --- /dev/null +++ b/crates/neo-memory/tests/riscv_shout_event_table_sparse_matrix.rs @@ -0,0 +1,117 @@ +use neo_math::K; +use neo_memory::riscv::exec_table::{Rv32ShoutEventRow, Rv32ShoutEventTable}; +use neo_memory::riscv::sparse_access::{ + rv32_shout_event_table_ra_val_mle_eval_chunked, rv32_shout_event_table_to_sparse_ra_and_val, +}; +use p3_field::PrimeCharacteristicRing; +use rand::Rng; +use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha8Rng; + +fn chi_at_u64_index(r: &[K], idx: u64) -> K { + let mut acc = K::ONE; + for (bit, &ri) in r.iter().enumerate() { + let is_one = 
((idx >> bit) & 1) == 1; + acc *= if is_one { ri } else { K::ONE - ri }; + } + acc +} + +#[test] +fn shout_event_table_sparse_ra_val_mle_matches_direct_sum() { + let mut rng = ChaCha8Rng::seed_from_u64(2026); + + // Small cycle domain so the direct sum is easy to compute. + let ell_cycle = 3usize; + let max_cycle = 1usize << ell_cycle; + + // A tiny synthetic Shout event table (keys are arbitrary u64s here). + let rows = vec![ + Rv32ShoutEventRow { + row_idx: 0, + cycle: 0, + pc: 0, + shout_id: 1, + opcode: None, + key: 0x0123_4567_89ab_cdef, + lhs: 0, + rhs: 0, + value: 11, + }, + Rv32ShoutEventRow { + row_idx: 1, + cycle: 1, + pc: 4, + shout_id: 2, + opcode: None, + key: 0xfedc_ba98_7654_3210, + lhs: 0, + rhs: 0, + value: 22, + }, + // Duplicate address at a different cycle (allowed). + Rv32ShoutEventRow { + row_idx: 4, + cycle: 4, + pc: 16, + shout_id: 1, + opcode: None, + key: 0x0123_4567_89ab_cdef, + lhs: 0, + rhs: 0, + value: 33, + }, + // Same cycle, different address (also allowed at the event-table layer). + Rv32ShoutEventRow { + row_idx: 4, + cycle: 4, + pc: 16, + shout_id: 3, + opcode: None, + key: 0x0000_0000_dead_beef, + lhs: 0, + rhs: 0, + value: 44, + }, + ]; + for r in rows.iter() { + assert!(r.row_idx < max_cycle, "row_idx must fit ell_cycle"); + } + let events = Rv32ShoutEventTable { rows }; + + let (ra, val) = rv32_shout_event_table_to_sparse_ra_and_val(&events, ell_cycle).expect("sparse mats"); + + // Random evaluation points. + let r_addr: Vec = (0..64).map(|_| K::from_u64(rng.random::())).collect(); + let r_cycle: Vec = (0..ell_cycle) + .map(|_| K::from_u64(rng.random::())) + .collect(); + + // Direct expected sums. 
+ let mut expected_ra = K::ZERO; + let mut expected_val = K::ZERO; + for row in events.rows.iter() { + let addr = row.key; + let cycle = row.row_idx as u64; + let w = chi_at_u64_index(&r_addr, addr) * chi_at_u64_index(&r_cycle, cycle); + expected_ra += w; + expected_val += K::from_u64(row.value) * w; + } + + let got_ra = ra + .mle_eval_by_folding(&r_addr, &r_cycle) + .expect("ra mle_eval_by_folding"); + let got_val = val + .mle_eval_by_folding(&r_addr, &r_cycle) + .expect("val mle_eval_by_folding"); + + assert_eq!(got_ra, expected_ra); + assert_eq!(got_val, expected_val); + + // Chunked Jolt-style eq-table evaluation (log_k_chunk=16 → 4 chunks). + let (got_ra_chunked, got_val_chunked) = + rv32_shout_event_table_ra_val_mle_eval_chunked(&events, &r_addr, &r_cycle, /*log_k_chunk=*/ 16) + .expect("chunked eval"); + assert_eq!(got_ra_chunked, expected_ra); + assert_eq!(got_val_chunked, expected_val); +} diff --git a/crates/neo-memory/tests/riscv_signed_div_rem_shared_bus_constraints.rs b/crates/neo-memory/tests/riscv_signed_div_rem_shared_bus_constraints.rs new file mode 100644 index 00000000..f29f39ac --- /dev/null +++ b/crates/neo-memory/tests/riscv_signed_div_rem_shared_bus_constraints.rs @@ -0,0 +1,113 @@ +use std::collections::HashMap; + +use neo_memory::plain::PlainMemLayout; +use neo_memory::riscv::ccs::{ + rv32_trace_shared_bus_requirements_with_specs, rv32_trace_shared_cpu_bus_config_with_specs, Rv32TraceCcsLayout, + TraceShoutBusSpec, +}; +use neo_memory::riscv::lookups::{RiscvOpcode, RiscvShoutTables, PROG_ID, RAM_ID, REG_ID}; +use neo_memory::riscv::trace::{rv32_decode_lookup_table_id_for_col, Rv32DecodeSidecarLayout}; +use p3_goldilocks::Goldilocks as F; + +fn sample_mem_layouts() -> HashMap { + HashMap::from([ + ( + PROG_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 1, + }, + ), + ( + REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + ( + RAM_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 
1, + }, + ), + ]) +} + +fn decode_selector_specs(prog_d: usize) -> Vec { + let decode = Rv32DecodeSidecarLayout::new(); + [decode.rd_has_write, decode.ram_has_read, decode.ram_has_write] + .into_iter() + .map(|col| TraceShoutBusSpec { + table_id: rv32_decode_lookup_table_id_for_col(col), + ell_addr: prog_d, + n_vals: 1usize, + }) + .collect() +} + +fn div_rem_table_ids() -> Vec { + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + [RiscvOpcode::Div, RiscvOpcode::Divu, RiscvOpcode::Rem, RiscvOpcode::Remu] + .into_iter() + .map(|op| shout.opcode_to_id(op).0) + .collect() +} + +#[test] +fn rv32_trace_shared_bus_requirements_accept_div_and_rem_tables() { + let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let table_ids = div_rem_table_ids(); + + let (bus_region_len, reserved_rows) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &table_ids, &decode_specs, &mem_layouts) + .expect("trace shared bus requirements for DIV/REM tables"); + + assert!(bus_region_len > 0, "expected non-zero bus region for DIV/REM tables"); + assert!( + reserved_rows > 0, + "expected injected constraints when shared-bus rows are reserved" + ); +} + +#[test] +fn rv32_trace_shared_bus_config_keeps_div_and_rem_tables_padding_only() { + let mut layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let table_ids = div_rem_table_ids(); + + let (bus_region_len, _) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &table_ids, &decode_specs, &mem_layouts) + .expect("trace shared bus requirements for DIV/REM tables"); + layout.m += bus_region_len; + + let cfg = rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &table_ids, + &decode_specs, + mem_layouts, + HashMap::<(u32, u64), F>::new(), + ) + .expect("trace 
shared bus config"); + + for &table_id in &table_ids { + let lanes = cfg + .shout_cpu + .get(&table_id) + .expect("missing shout_cpu entry for DIV/REM table"); + assert!( + lanes.is_empty(), + "trace mode must keep DIV/REM tables as padding-only shout bindings (table_id={table_id})" + ); + } +} diff --git a/crates/neo-memory/tests/riscv_single_instruction_constraints.rs b/crates/neo-memory/tests/riscv_single_instruction_constraints.rs new file mode 100644 index 00000000..5bbd6f25 --- /dev/null +++ b/crates/neo-memory/tests/riscv_single_instruction_constraints.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; + +use neo_memory::plain::PlainMemLayout; +use neo_memory::riscv::ccs::{ + build_rv32_trace_wiring_ccs, build_rv32_trace_wiring_ccs_with_reserved_rows, + rv32_trace_shared_bus_requirements_with_specs, Rv32TraceCcsLayout, TraceShoutBusSpec, +}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, + RAM_ID, REG_ID, +}; +use neo_memory::riscv::trace::{rv32_decode_lookup_table_id_for_col, Rv32DecodeSidecarLayout}; +use neo_vm_trace::trace_program; + +fn addi_halt_exec_table() -> Rv32ExecTable { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + 
exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + exec +} + +fn sample_mem_layouts() -> HashMap { + HashMap::from([ + ( + PROG_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 1, + }, + ), + ( + REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + ( + RAM_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 1, + }, + ), + ]) +} + +fn decode_selector_specs(prog_d: usize) -> Vec { + let decode = Rv32DecodeSidecarLayout::new(); + [decode.rd_has_write, decode.ram_has_read, decode.ram_has_write] + .into_iter() + .map(|col| TraceShoutBusSpec { + table_id: rv32_decode_lookup_table_id_for_col(col), + ell_addr: prog_d, + n_vals: 1usize, + }) + .collect() +} + +#[test] +fn trace_single_addi_constraint_shapes_are_consistent() { + let exec = addi_halt_exec_table(); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + assert_eq!(layout.t, 4, "expected power-of-two padded rows for ADDI+HALT"); + assert_eq!( + layout.m, + layout.m_in + layout.trace.cols * layout.t, + "layout width regression" + ); + assert_eq!(ccs.m, layout.m, "CCS witness width must match layout"); + assert!(ccs.n > layout.t, "CCS must contain non-trivial constraints"); +} + +#[test] +fn trace_single_addi_reserved_rows_affect_constraints_only() { + let exec = addi_halt_exec_table(); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let ccs_base = build_rv32_trace_wiring_ccs(&layout).expect("base trace CCS"); + + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let add_table_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0; + let (_bus_region_len, reserved_rows) = + 
rv32_trace_shared_bus_requirements_with_specs(&layout, &[add_table_id], &decode_specs, &mem_layouts) + .expect("trace shared bus requirements"); + + let ccs_reserved = + build_rv32_trace_wiring_ccs_with_reserved_rows(&layout, reserved_rows).expect("trace CCS with reserved rows"); + + assert!( + reserved_rows > 0, + "expected non-zero reserved rows for shared bus padding" + ); + assert_eq!( + ccs_reserved.n, + ccs_base.n + reserved_rows, + "reserved rows should only increase row count" + ); + assert_eq!( + ccs_reserved.m, ccs_base.m, + "reserved rows should not change witness width" + ); +} diff --git a/crates/neo-memory/tests/riscv_trace_air.rs b/crates/neo-memory/tests/riscv_trace_air.rs new file mode 100644 index 00000000..9df96bc4 --- /dev/null +++ b/crates/neo-memory/tests/riscv_trace_air.rs @@ -0,0 +1,123 @@ +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, +}; +use neo_memory::riscv::trace::{Rv32TraceAir, Rv32TraceWitness}; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +#[test] +fn rv32_trace_air_satisfies_addi_halt() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + assert_eq!(trace.steps.len(), 2, "expected ADDI + HALT trace"); + + let exec = 
Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let air = Rv32TraceAir::new(); + let wit = Rv32TraceWitness::from_exec_table(&air.layout, &exec).expect("trace witness"); + air.assert_satisfied(&wit).expect("trace AIR satisfied"); +} + +#[test] +fn rv32_trace_air_rejects_halted_tail_reactivation() { + // Program with at least one active transition after row 0. + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + + let air = Rv32TraceAir::new(); + let mut wit = Rv32TraceWitness::from_exec_table(&air.layout, &exec).expect("trace witness"); + + // Force halted=1 from row 0 onward, while keeping row 1 active=1 in the witness. + // This should violate the halted tail quiescence transition check at row 0. 
+ for row in 0..wit.t { + wit.cols[air.layout.halted][row] = F::ONE; + } + + let err = air + .assert_satisfied(&wit) + .expect_err("mutated witness should violate halted tail quiescence"); + assert!( + err.contains("halted tail quiescence violated"), + "unexpected error: {err}" + ); +} + +#[test] +fn rv32_trace_air_rejects_non_boolean_active() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + + let air = Rv32TraceAir::new(); + let mut wit = Rv32TraceWitness::from_exec_table(&air.layout, &exec).expect("trace witness"); + wit.cols[air.layout.active][0] = F::from_u64(2); + + let err = air + .assert_satisfied(&wit) + .expect_err("mutated witness should violate bit booleanity"); + assert!(err.contains("active not boolean"), "unexpected error: {err}"); +} diff --git a/crates/neo-memory/tests/riscv_trace_shared_bus_w1.rs b/crates/neo-memory/tests/riscv_trace_shared_bus_w1.rs new file mode 100644 index 00000000..33d75de7 --- /dev/null +++ b/crates/neo-memory/tests/riscv_trace_shared_bus_w1.rs @@ -0,0 +1,398 @@ +use std::collections::HashMap; + +use neo_memory::cpu::CPU_BUS_COL_DISABLED; +use neo_memory::plain::PlainMemLayout; +use neo_memory::riscv::ccs::{ + rv32_trace_shared_bus_requirements_with_specs, rv32_trace_shared_cpu_bus_config_with_specs, Rv32TraceCcsLayout, + 
TraceShoutBusSpec, +}; +use neo_memory::riscv::lookups::{RiscvOpcode, RiscvShoutTables, PROG_ID, RAM_ID, REG_ID}; +use neo_memory::riscv::trace::{ + rv32_decode_lookup_backed_cols, rv32_decode_lookup_table_id_for_col, rv32_trace_lookup_addr_group_for_table_id, + rv32_trace_lookup_selector_group_for_table_id, rv32_width_lookup_backed_cols, rv32_width_lookup_table_id_for_col, + Rv32DecodeSidecarLayout, Rv32WidthSidecarLayout, +}; +use p3_goldilocks::Goldilocks as F; + +fn sample_mem_layouts() -> HashMap { + HashMap::from([ + ( + PROG_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 1, + }, + ), + ( + REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + ( + RAM_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 1, + }, + ), + ]) +} + +fn decode_selector_specs(prog_d: usize) -> Vec { + let decode = Rv32DecodeSidecarLayout::new(); + [decode.rd_has_write, decode.ram_has_read, decode.ram_has_write] + .into_iter() + .map(|col| TraceShoutBusSpec { + table_id: rv32_decode_lookup_table_id_for_col(col), + ell_addr: prog_d, + n_vals: 1usize, + }) + .collect() +} + +fn width_selector_specs(cycle_d: usize) -> Vec { + let width = Rv32WidthSidecarLayout::new(); + [width.ram_rv_q16, width.rs2_q16] + .into_iter() + .map(|col| TraceShoutBusSpec { + table_id: rv32_width_lookup_table_id_for_col(col), + ell_addr: cycle_d, + n_vals: 1usize, + }) + .collect() +} + +fn full_table_ids() -> Vec { + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + [ + RiscvOpcode::And, + RiscvOpcode::Xor, + RiscvOpcode::Or, + RiscvOpcode::Add, + RiscvOpcode::Sub, + RiscvOpcode::Slt, + RiscvOpcode::Sltu, + RiscvOpcode::Sll, + RiscvOpcode::Srl, + RiscvOpcode::Sra, + RiscvOpcode::Eq, + RiscvOpcode::Neq, + RiscvOpcode::Mul, + RiscvOpcode::Mulh, + RiscvOpcode::Mulhu, + RiscvOpcode::Mulhsu, + RiscvOpcode::Div, + RiscvOpcode::Divu, + RiscvOpcode::Rem, + RiscvOpcode::Remu, + ] + .into_iter() + .map(|op| shout.opcode_to_id(op).0) + .collect() +} + +#[test] +fn 
rv32_trace_shared_bus_config_uses_padding_only_shout_bindings_for_all_tables() { + let mut layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let table_ids = full_table_ids(); + let (bus_region_len, _) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &table_ids, &decode_specs, &mem_layouts) + .expect("trace shared bus requirements"); + layout.m += bus_region_len; + let cfg = rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &table_ids, + &decode_specs, + mem_layouts, + HashMap::<(u32, u64), F>::new(), + ) + .expect("trace shared bus config"); + + for &table_id in &table_ids { + let lanes = cfg + .shout_cpu + .get(&table_id) + .expect("missing shout_cpu entry for table"); + assert!( + lanes.is_empty(), + "trace shared bus must use padding-only shout bindings (table_id={table_id})" + ); + } +} + +#[test] +fn rv32_trace_shared_bus_requirements_accept_rv32m_table_ids() { + let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let table_ids = full_table_ids(); + let (bus_region_len, reserved_rows) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &table_ids, &decode_specs, &mem_layouts) + .expect("trace shared bus requirements"); + assert!( + bus_region_len > 0, + "expected non-zero bus region for full table profile" + ); + assert!( + reserved_rows > 0, + "expected injected bus constraints for shout padding rows" + ); +} + +#[test] +fn rv32_trace_shared_bus_requirements_reject_unknown_table_id() { + let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let err = rv32_trace_shared_bus_requirements_with_specs(&layout, &[999u32], &decode_specs, 
&mem_layouts) + .expect_err("unknown table id must be rejected"); + assert!( + err.contains("unsupported shout table_id=999"), + "unexpected error: {err}" + ); +} + +#[test] +fn rv32_trace_shared_bus_with_specs_adds_custom_shout_width() { + let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let mut specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let (bus_region_base, _) = rv32_trace_shared_bus_requirements_with_specs(&layout, &[3u32], &specs, &mem_layouts) + .expect("trace shared bus baseline requirements"); + specs.push(TraceShoutBusSpec { + table_id: 1000, + ell_addr: 13, + n_vals: 1usize, + }); + let (bus_region_len, reserved_rows) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &[3u32], &specs, &mem_layouts) + .expect("trace shared bus requirements with extra spec"); + + let expected_extra_cols = 13 + 2; + assert_eq!( + bus_region_len - bus_region_base, + expected_extra_cols * layout.t, + "bus width delta must include custom extra shout ell_addr" + ); + assert!(reserved_rows > 0, "expected injected padding constraints"); +} + +#[test] +fn rv32_trace_shared_cpu_bus_config_with_specs_keeps_padding_only_bindings() { + let mut layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let mut specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + specs.push(TraceShoutBusSpec { + table_id: 1001, + ell_addr: 17, + n_vals: 1usize, + }); + let (bus_region_len, _) = rv32_trace_shared_bus_requirements_with_specs(&layout, &[3u32], &specs, &mem_layouts) + .expect("trace shared bus requirements with extra spec"); + layout.m += bus_region_len; + let cfg = rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &[3u32], + &specs, + mem_layouts, + HashMap::<(u32, u64), F>::new(), + ) + .expect("trace shared bus config with extra spec"); + + let base = cfg.shout_cpu.get(&3u32).expect("missing base shout table"); + 
assert!(base.is_empty(), "base table must stay padding-only"); + let custom = cfg + .shout_cpu + .get(&1001u32) + .expect("missing custom shout table"); + assert!(custom.is_empty(), "custom table must use padding-only bindings"); +} + +#[test] +fn rv32_trace_shared_bus_with_specs_rejects_conflicting_ell_addr() { + let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let mut extra = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + extra.push(TraceShoutBusSpec { + table_id: 3, + ell_addr: 63, + n_vals: 1usize, + }); + let err = rv32_trace_shared_bus_requirements_with_specs(&layout, &[3u32], &extra, &mem_layouts) + .expect_err("conflicting table width must fail"); + assert!(err.contains("conflicting ell_addr"), "unexpected error: {err}"); +} + +#[test] +fn rv32_trace_shared_cpu_bus_config_with_specs_binds_decode_lookup_key_to_pc_before() { + let mut layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let table_ids = full_table_ids(); + let (bus_region_len, _) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &table_ids, &decode_specs, &mem_layouts) + .expect("trace shared bus requirements"); + layout.m += bus_region_len; + let cfg = rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &table_ids, + &decode_specs, + mem_layouts, + HashMap::<(u32, u64), F>::new(), + ) + .expect("trace shared bus config"); + + let decode = Rv32DecodeSidecarLayout::new(); + let table_id = rv32_decode_lookup_table_id_for_col(decode.rd_has_write); + let lanes = cfg + .shout_cpu + .get(&table_id) + .expect("missing decode shout_cpu binding"); + assert_eq!(lanes.len(), 1, "decode lookup should bind one shout lane"); + let lane = &lanes[0]; + assert_eq!( + lane.has_lookup, CPU_BUS_COL_DISABLED, + "decode lookup should use key-only linkage (selector disabled)" + ); + assert_eq!( + 
lane.val, CPU_BUS_COL_DISABLED, + "decode lookup should use key-only linkage (value disabled)" + ); + assert_eq!( + lane.addr, + Some(layout.cell(layout.trace.pc_before, 0)), + "decode lookup key must bind to committed pc_before" + ); +} + +#[test] +fn rv32_trace_shared_cpu_bus_config_with_specs_binds_width_lookup_key_to_cycle() { + let mut layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + let mem_layouts = sample_mem_layouts(); + let mut specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + specs.extend(width_selector_specs(/*cycle_d=*/ 8)); + let table_ids = full_table_ids(); + let (bus_region_len, _) = rv32_trace_shared_bus_requirements_with_specs(&layout, &table_ids, &specs, &mem_layouts) + .expect("trace shared bus requirements"); + layout.m += bus_region_len; + let cfg = rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &table_ids, + &specs, + mem_layouts, + HashMap::<(u32, u64), F>::new(), + ) + .expect("trace shared bus config"); + + let width = Rv32WidthSidecarLayout::new(); + let table_id = rv32_width_lookup_table_id_for_col(width.ram_rv_q16); + let lanes = cfg + .shout_cpu + .get(&table_id) + .expect("missing width shout_cpu binding"); + assert_eq!(lanes.len(), 1, "width lookup should bind one shout lane"); + let lane = &lanes[0]; + assert_eq!( + lane.has_lookup, CPU_BUS_COL_DISABLED, + "width lookup should use key-only linkage (selector disabled)" + ); + assert_eq!( + lane.val, CPU_BUS_COL_DISABLED, + "width lookup should use key-only linkage (value disabled)" + ); + assert_eq!( + lane.addr, + Some(layout.cell(layout.trace.cycle, 0)), + "width lookup key must bind to committed cycle" + ); +} + +#[test] +fn rv32_trace_lookup_addr_group_coalesces_all_decode_lookup_backed_tables() { + let decode = Rv32DecodeSidecarLayout::new(); + let cols = rv32_decode_lookup_backed_cols(&decode); + assert!(!cols.is_empty(), "decode lookup-backed set must be non-empty"); + + let mut groups = std::collections::BTreeSet::new(); + for col 
in cols { + let table_id = rv32_decode_lookup_table_id_for_col(col); + let group = rv32_trace_lookup_addr_group_for_table_id(table_id); + assert!(group.is_some(), "decode table_id={table_id} must have an addr group"); + groups.insert(group); + } + assert_eq!( + groups.len(), + 1, + "all decode lookup-backed tables should share one address group" + ); +} + +#[test] +fn rv32_trace_lookup_addr_group_coalesces_all_width_lookup_tables() { + let width = Rv32WidthSidecarLayout::new(); + let cols = rv32_width_lookup_backed_cols(&width); + assert!(!cols.is_empty(), "width lookup-backed set must be non-empty"); + + let mut groups = std::collections::BTreeSet::new(); + for col in cols { + let table_id = rv32_width_lookup_table_id_for_col(col); + let group = rv32_trace_lookup_addr_group_for_table_id(table_id); + assert!(group.is_some(), "width table_id={table_id} must have an addr group"); + groups.insert(group); + } + assert_eq!( + groups.len(), + 1, + "all width lookup-backed tables should share one address group" + ); +} + +#[test] +fn rv32_trace_lookup_selector_group_coalesces_all_decode_lookup_backed_tables() { + let decode = Rv32DecodeSidecarLayout::new(); + let cols = rv32_decode_lookup_backed_cols(&decode); + assert!(!cols.is_empty(), "decode lookup-backed set must be non-empty"); + + let mut groups = std::collections::BTreeSet::new(); + for col in cols { + let table_id = rv32_decode_lookup_table_id_for_col(col); + let group = rv32_trace_lookup_selector_group_for_table_id(table_id); + assert!(group.is_some(), "decode table_id={table_id} must have a selector group"); + groups.insert(group); + } + assert_eq!( + groups.len(), + 1, + "all decode lookup-backed tables should share one selector group" + ); +} + +#[test] +fn rv32_trace_lookup_selector_group_coalesces_all_width_lookup_tables() { + let width = Rv32WidthSidecarLayout::new(); + let cols = rv32_width_lookup_backed_cols(&width); + assert!(!cols.is_empty(), "width lookup-backed set must be non-empty"); + + let mut 
groups = std::collections::BTreeSet::new(); + for col in cols { + let table_id = rv32_width_lookup_table_id_for_col(col); + let group = rv32_trace_lookup_selector_group_for_table_id(table_id); + assert!(group.is_some(), "width table_id={table_id} must have a selector group"); + groups.insert(group); + } + assert_eq!( + groups.len(), + 1, + "all width lookup-backed tables should share one selector group" + ); +} diff --git a/crates/neo-memory/tests/riscv_trace_sidecar_extract.rs b/crates/neo-memory/tests/riscv_trace_sidecar_extract.rs new file mode 100644 index 00000000..97ed0e1b --- /dev/null +++ b/crates/neo-memory/tests/riscv_trace_sidecar_extract.rs @@ -0,0 +1,178 @@ +use std::collections::HashMap; + +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, RiscvOpcode, RiscvShoutTables, + PROG_ID, +}; +use neo_memory::riscv::trace::{extract_shout_lanes_over_time, extract_twist_lanes_over_time}; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +fn build_exec_table() -> (Rv32ExecTable, Vec) { + // Program: + // - ADDI x1, x0, 1 + // - SW x1, 0(x0) + // - LW x2, 0(x0) + // - HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0, + }, + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, 
twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let shout_table_ids = vec![RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0]; + (exec, shout_table_ids) +} + +#[test] +fn trace_sidecar_extract_smoke() { + let (exec, shout_table_ids) = build_exec_table(); + + // Keep it tiny: RAM addresses in this program are 0, so ell_addr=2 is enough. + let init_regs: HashMap = HashMap::new(); + let init_ram: HashMap = HashMap::new(); + let twist = + extract_twist_lanes_over_time(&exec, &init_regs, &init_ram, /*ram_ell_addr=*/ 2).expect("twist extract"); + let shout = extract_shout_lanes_over_time(&exec, &shout_table_ids).expect("shout extract"); + + assert_eq!(twist.prog.has_read.len(), exec.rows.len()); + assert_eq!(twist.reg_lane0.has_read.len(), exec.rows.len()); + assert_eq!(twist.reg_lane1.has_read.len(), exec.rows.len()); + assert_eq!(twist.ram.has_read.len(), exec.rows.len()); + assert_eq!(shout.len(), 1); + assert_eq!(shout[0].has_lookup.len(), exec.rows.len()); + + // Inactive tail must be all zero/false. + for (i, row) in exec.rows.iter().enumerate() { + if row.active { + continue; + } + assert!(!twist.prog.has_read[i]); + assert!(!twist.reg_lane0.has_read[i]); + assert!(!twist.reg_lane1.has_read[i]); + assert!(!twist.reg_lane0.has_write[i]); + assert!(!twist.ram.has_read[i]); + assert!(!twist.ram.has_write[i]); + assert_eq!(twist.reg_lane0.inc_at_write_addr[i], F::ZERO); + assert_eq!(twist.ram.inc_at_write_addr[i], F::ZERO); + assert!(!shout[0].has_lookup[i]); + assert_eq!(shout[0].key[i], 0); + assert_eq!(shout[0].value[i], 0); + } + + // Sanity: should contain at least one RAM write (SW) and one RAM read (LW). 
+ assert!(twist.ram.has_write.iter().any(|&b| b), "expected a RAM write"); + assert!(twist.ram.has_read.iter().any(|&b| b), "expected a RAM read"); +} + +#[test] +fn trace_sidecar_extract_rejects_multiple_shout_events() { + let (mut exec, shout_table_ids) = build_exec_table(); + + let first_active = exec + .rows + .iter() + .position(|r| r.active) + .expect("must have active rows"); + let ev = neo_vm_trace::ShoutEvent:: { + shout_id: neo_vm_trace::ShoutId(shout_table_ids[0]), + key: 0, + value: 0, + }; + exec.rows[first_active].shout_events.push(ev.clone()); + exec.rows[first_active].shout_events.push(ev); + + let err = extract_shout_lanes_over_time(&exec, &shout_table_ids).unwrap_err(); + assert!(err.contains("multiple Shout events"), "{err}"); +} + +#[test] +fn trace_sidecar_extract_rejects_multiple_ram_writes() { + let (mut exec, _shout_table_ids) = build_exec_table(); + + let sw_row = exec + .rows + .iter() + .position(|r| { + r.ram_events + .iter() + .any(|e| matches!(e.kind, neo_vm_trace::TwistOpKind::Write)) + }) + .expect("expected a RAM write row"); + let write_ev = exec.rows[sw_row] + .ram_events + .iter() + .find(|e| matches!(e.kind, neo_vm_trace::TwistOpKind::Write)) + .cloned() + .expect("write event"); + exec.rows[sw_row].ram_events.push(write_ev); + + let init_regs: HashMap = HashMap::new(); + let init_ram: HashMap = HashMap::new(); + let err = extract_twist_lanes_over_time(&exec, &init_regs, &init_ram, /*ram_ell_addr=*/ 2).unwrap_err(); + assert!(err.contains("multiple RAM writes"), "{err}"); +} + +#[test] +fn trace_sidecar_extract_rejects_read_write_addr_mismatch() { + let (mut exec, _shout_table_ids) = build_exec_table(); + + let sw_row = exec + .rows + .iter() + .position(|r| { + r.ram_events + .iter() + .any(|e| matches!(e.kind, neo_vm_trace::TwistOpKind::Write)) + }) + .expect("expected a RAM write row"); + let write_ev = exec.rows[sw_row] + .ram_events + .iter() + .find(|e| matches!(e.kind, neo_vm_trace::TwistOpKind::Write)) + .cloned() + 
.expect("write event"); + + let mut read_ev = write_ev.clone(); + read_ev.kind = neo_vm_trace::TwistOpKind::Read; + read_ev.addr = write_ev.addr + 1; + exec.rows[sw_row].ram_events.push(read_ev); + + let init_regs: HashMap = HashMap::new(); + let init_ram: HashMap = HashMap::new(); + let err = extract_twist_lanes_over_time(&exec, &init_regs, &init_ram, /*ram_ell_addr=*/ 2).unwrap_err(); + assert!(err.contains("RAM read/write addr mismatch"), "{err}"); +} diff --git a/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs b/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs new file mode 100644 index 00000000..02b19274 --- /dev/null +++ b/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs @@ -0,0 +1,1589 @@ +use neo_ccs::relations::check_ccs_rowwise_zero; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, RAM_ID, +}; +use neo_vm_trace::trace_program; +use neo_vm_trace::Twist as _; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +#[test] +fn rv32_trace_layout_removes_fixed_shout_table_selector_lanes() { + let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); + + assert_eq!( + layout.trace.cols, 21, + "trace width regression: expected 21 columns after shout_lhs/jalr_drop_bit hardening" + ); + assert_eq!( + layout.trace.cols, + layout.trace.jalr_drop_bit + 1, + "trace layout should remain densely packed" + ); +} + +#[test] +fn rv32_trace_wiring_ccs_satisfies_addi_halt() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = 
decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + check_ccs_rowwise_zero(&ccs, &x, &w).expect("trace CCS satisfied"); +} + +#[test] +fn rv32_trace_wiring_ccs_satisfies_mixed_addi_andi_halt() { + // Program: ADDI x1, x0, 1; ANDI x2, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::And, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + 
exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + check_ccs_rowwise_zero(&ccs, &x, &w).expect("trace CCS satisfied"); +} + +#[test] +fn rv32_trace_wiring_ccs_satisfies_mixed_addi_ori_halt() { + // Program: ADDI x1, x0, 1; ORI x2, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + check_ccs_rowwise_zero(&ccs, &x, 
&w).expect("trace CCS satisfied"); +} + +#[test] +fn rv32_trace_wiring_ccs_satisfies_addi_sw_lw_halt() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0, + }, + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 32).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + check_ccs_rowwise_zero(&ccs, &x, &w).expect("trace CCS satisfied"); +} + +#[test] +fn rv32_trace_wiring_ccs_satisfies_lui_x0_halt() { + // Program: LUI x0, 1; HALT + // Architecturally this must be satisfiable (x0 discards the writeback). 
+ let program = vec![RiscvInstruction::Lui { rd: 0, imm: 1 }, RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_ok(), + "LUI x0 must be satisfiable in trace CCS" + ); +} + +#[test] +fn rv32_trace_wiring_ccs_rejects_all_inactive_padding_witness() { + // Program: ADDI x1, x0, 1; HALT + // + // Red-team target: with no explicit execution anchor, an all-inactive witness can + // satisfy most gated constraints while only honoring public bindings + chains. 
+ let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let t = exec.rows.len(); + let layout = Rv32TraceCcsLayout::new(t).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + let mut set = |col: usize, row: usize, value: F| { + let idx = layout.cell(col, row); + w[idx - layout.m_in] = value; + }; + + // Start from an all-zero trace region. + for col in 0..layout.trace.cols { + for row in 0..t { + set(col, row, F::ZERO); + } + } + + // Public bindings that must continue to hold. + set(layout.trace.pc_before, 0, x[layout.pc0]); + set(layout.trace.pc_after, t - 1, x[layout.pc_final]); + set(layout.trace.halted, 0, x[layout.halted_in]); + set(layout.trace.halted, t - 1, x[layout.halted_out]); + + // Keep cycle and pc chains valid. + for row in 0..t { + set(layout.trace.cycle, row, F::from_u64(row as u64)); + if row > 0 { + set(layout.trace.pc_before, row, F::ZERO); + } + if row < (t - 1) { + set(layout.trace.pc_after, row, F::ZERO); + } + } + + // Force all rows inactive. 
+ for row in 0..t { + set(layout.trace.active, row, F::ZERO); + } + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "all-inactive witness should be rejected by an execution anchor" + ); +} + +#[test] +fn rv32_trace_wiring_ccs_rejects_trace_one_column_tamper() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Tamper trace-local `one` column; production CCS should reject this. + let one_idx = layout.cell(layout.trace.one, 0); + w[one_idx - layout.m_in] = F::ZERO; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered trace.one should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to control stage claim-only control-flow semantics"] +fn rv32_trace_wiring_ccs_rejects_jalr_misaligned_pc_after() { + // Program: + // ADDI x1, x0, 8 + // JALR x2, x1, 0 + // BEQ x0, x0, 4 + // HALT + // + // Red-team target: force row1.pc_after to an odd value while keeping + // pc_after + drop0 + 2*drop1 == rs1 + imm_i and global chaining intact. 
+    let program = vec![
+        RiscvInstruction::IAlu {
+            op: RiscvOpcode::Add,
+            rd: 1,
+            rs1: 0,
+            imm: 8,
+        },
+        RiscvInstruction::Jalr { rd: 2, rs1: 1, imm: 0 },
+        RiscvInstruction::Branch {
+            cond: BranchCondition::Eq,
+            rs1: 0,
+            rs2: 0,
+            imm: 4,
+        },
+        RiscvInstruction::Halt,
+    ];
+    let program_bytes = encode_program(&program);
+
+    let decoded_program = decode_program(&program_bytes).expect("decode_program");
+    let mut cpu = RiscvCpu::new(/*xlen=*/ 32);
+    cpu.load_program(/*base=*/ 0, decoded_program);
+    let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes);
+    let shout = RiscvShoutTables::new(/*xlen=*/ 32);
+    let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program");
+
+    let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2");
+    let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout");
+    let (mut x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness");
+    let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS");
+
+    let t = exec.rows.len();
+
+    // Shift JALR target chain by -1 from row1 forward (including padded tail),
+    // while preserving local control-flow equations and pc chaining.
+    for row in 1..t {
+        let idx = layout.cell(layout.trace.pc_after, row);
+        let new_pc_after = w[idx - layout.m_in] - F::ONE;
+        w[idx - layout.m_in] = new_pc_after;
+    }
+    for row in 2..t {
+        let pc_before_idx = layout.cell(layout.trace.pc_before, row);
+        let new_pc_before = w[pc_before_idx - layout.m_in] - F::ONE;
+        w[pc_before_idx - layout.m_in] = new_pc_before;
+        // NOTE(review): the program-fetch address is carried by this same
+        // `pc_before` column (see the prog_value/instr_word aliasing used
+        // elsewhere in this file), so the decrement above already shifts the
+        // active rows' fetch address. The previous extra conditional decrement
+        // wrote the same cell twice, breaking the pc chain this witness
+        // explicitly tries to keep intact (making the is_err check vacuous).
+    }
+
+    // Keep public pc_final consistent with the shifted tail.
+ x[layout.pc_final] -= F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "misaligned JALR pc_after should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_spurious_ram_addr_on_non_memory_row() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Row 0 is ALU (non-memory): keep RAM flags at 0 but inject a spurious address. 
+ let row0_ram_addr = layout.cell(layout.trace.ram_addr, 0); + w[row0_ram_addr - layout.m_in] = F::from_u64(1234); + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "non-memory row with spurious ram_addr should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to shared-bus PROG/decode linkage semantics"] +fn rv32_trace_wiring_ccs_rejects_prog_value_tamper() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Flip PROG value for the first row (active row), which should violate + // active -> (prog_value == instr_word). 
+ let prog_value_idx = layout.cell(layout.trace.instr_word, 0); + w[prog_value_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered witness should fail trace CCS" + ); +} + +#[test] +fn rv32_trace_wiring_ccs_rejects_halted_tail_pc_drift() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Red-team: keep continuity constraints satisfied but drift the halted tail PC. + // row2.pc_after += 1 and row3.pc_before += 1 preserves + // `pc_after[2] == pc_before[3]` while violating halted-tail quiescence. 
+ let row2_pc_after_idx = layout.cell(layout.trace.pc_after, 2); + let row3_pc_before_idx = layout.cell(layout.trace.pc_before, 3); + w[row2_pc_after_idx - layout.m_in] += F::ONE; + w[row3_pc_before_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "halted-tail PC drift should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_halt_flag_mismatch_on_active_row() { + // Program: ADDI x1, x0, 1; HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Row 1 is HALT/system; forge halted=0. 
+ let row1_halted_idx = layout.cell(layout.trace.halted, 1); + w[row1_halted_idx - layout.m_in] = F::ZERO; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "active HALT row with halted=0 must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode-stage lookup semantics"] +fn rv32_trace_wiring_ccs_rejects_decode_bit_tamper() { + // Program: ADDI x1, x0, 1; HALT + // + // Target production behavior: opcode/decoded fields are semantically bound to instr_word. + // This test is expected to fail until trace semantics are enforced (not just wiring). + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Tamper a trace-local scalar on an active row. 
+ let rs1_addr_idx = layout.cell(layout.trace.rs1_addr, 0); + w[rs1_addr_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered decode bit should not satisfy production-grade trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_lui_writeback_tamper() { + // Program: LUI x1, 1; HALT + // + // Target production behavior: rd_val/writeback must satisfy ISA semantics. + // This test is expected to fail until trace semantics are enforced. + let program = vec![RiscvInstruction::Lui { rd: 1, imm: 1 }, RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Tamper rd writeback value on the LUI row. 
+ let rd_val_idx = layout.cell(layout.trace.rd_val, 0); + w[rd_val_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered LUI writeback should not satisfy production-grade trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_auipc_writeback_tamper() { + // Program: AUIPC x1, 1; HALT + let program = vec![RiscvInstruction::Auipc { rd: 1, imm: 1 }, RiscvInstruction::Halt]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Tamper rd writeback value on the AUIPC row. + let rd_val_idx = layout.cell(layout.trace.rd_val, 0); + w[rd_val_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered AUIPC writeback should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_jal_link_writeback_tamper() { + // Program: JAL x1, 8; ADDI x2, x0, 1; HALT + // Jump skips over ADDI; JAL link value should be pc_before + 4. 
+ let program = vec![ + RiscvInstruction::Jal { rd: 1, imm: 8 }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Tamper rd writeback value on the JAL row. + let rd_val_idx = layout.cell(layout.trace.rd_val, 0); + w[rd_val_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered JAL link writeback should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_jalr_link_writeback_tamper() { + // Program: + // ADDI x1, x0, 8 + // JALR x2, x1, 0 + // HALT + // JALR link value should be pc_before + 4 on row 1. 
+ let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 8, + }, + RiscvInstruction::Jalr { rd: 2, rs1: 1, imm: 0 }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Tamper rd writeback value on the JALR row (row 1). + let rd_val_idx = layout.cell(layout.trace.rd_val, 1); + w[rd_val_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered JALR link writeback should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_non_branch_pc_update_tamper() { + // Program: ADDI x1, x0, 1; ADDI x2, x1, 2; HALT + // + // Target production behavior: non-branch rows must apply the correct PC update rule. + // This test is expected to fail until trace semantics are enforced. 
+    let program = vec![
+        RiscvInstruction::IAlu {
+            op: RiscvOpcode::Add,
+            rd: 1,
+            rs1: 0,
+            imm: 1,
+        },
+        RiscvInstruction::IAlu {
+            op: RiscvOpcode::Add,
+            rd: 2,
+            rs1: 1,
+            imm: 2,
+        },
+        RiscvInstruction::Halt,
+    ];
+    let program_bytes = encode_program(&program);
+
+    let decoded_program = decode_program(&program_bytes).expect("decode_program");
+    let mut cpu = RiscvCpu::new(/*xlen=*/ 32);
+    cpu.load_program(/*base=*/ 0, decoded_program);
+    let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes);
+    let shout = RiscvShoutTables::new(/*xlen=*/ 32);
+    let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program");
+
+    let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2");
+    let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout");
+    let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness");
+    let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS");
+
+    // Keep wiring equalities intact but drift a straight-line PC transition:
+    // row0.pc_after := row0.pc_after + 4
+    // row1.pc_before := row1.pc_before + 4
+    // NOTE(review): prog_addr is carried by the `pc_before` column, so the
+    // pc_before bump already preserves active->prog_addr==pc_before. A second
+    // write to the same cell would double-shift it and break the pc chain.
+    let row0_pc_after_idx = layout.cell(layout.trace.pc_after, 0);
+    let row1_pc_before_idx = layout.cell(layout.trace.pc_before, 1);
+    let delta = F::from_u64(4);
+    w[row0_pc_after_idx - layout.m_in] += delta;
+    w[row1_pc_before_idx - layout.m_in] += delta;
+
+    assert!(
+        check_ccs_rowwise_zero(&ccs, &x, &w).is_err(),
+        "tampered non-branch PC update should not satisfy production-grade trace CCS"
+    );
+}
+
+#[test]
+#[ignore = "moved to decode stage sidecar semantics"]
+fn rv32_trace_wiring_ccs_rejects_missing_writeback_on_addi() {
+    // Program: ADDI x1, x0, 1; HALT
+
let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Row 0 is ADDI with rd=1. Forge "no writeback" by clearing the write address/value. 
+ let row0_rd_addr = layout.cell(layout.trace.rd_addr, 0); + let row0_rd_val = layout.cell(layout.trace.rd_val, 0); + w[row0_rd_addr - layout.m_in] = F::ZERO; + w[row0_rd_val - layout.m_in] = F::ZERO; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "ADDI row without required writeback must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to width stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_load_without_ram_read() { + // Program: LW x1, 0(x0); HALT + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let mut twist = + RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + twist.store(RAM_ID, /*addr=*/ 0, /*value=*/ 7); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Tamper load row by clearing the read value. 
+ let row0_ram_rv = layout.cell(layout.trace.ram_rv, 0); + w[row0_ram_rv - layout.m_in] = F::ZERO; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "load row without RAM read must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to width stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_store_without_ram_write() { + // Program: ADDI x1, x0, 9; SW x1, 0(x0); HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 9, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Row 1 is SW. Clear write value. + let row1_ram_wv = layout.cell(layout.trace.ram_wv, 1); + w[row1_ram_wv - layout.m_in] = F::ZERO; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "store row without RAM write must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to width stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_store_with_spurious_rd_writeback() { + // Program: ADDI x1, x0, 5; SW x1, 4(x0); HALT + // + // S-type encodes imm[4:0] in the "rd" field position. 
With imm=4 this field is non-zero, + // so a forged rd writeback can satisfy rd_addr==rd unless we enforce store no-writeback policy. + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 5, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 4, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Row 1 is SW. Forge a writeback-like address/value. 
+ let row1_rd_addr = layout.cell(layout.trace.rd_addr, 1); + let row1_rd_val = layout.cell(layout.trace.rd_val, 1); + w[row1_rd_addr - layout.m_in] = F::from_u64(4); + w[row1_rd_val - layout.m_in] = F::from_u64(9); + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "store row with forged rd writeback must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_load_pc_update_tamper() { + // Program: LW x1, 0(x0); LW x2, 0(x0); HALT + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 2, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let mut twist = + RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + twist.store(RAM_ID, /*addr=*/ 0, /*value=*/ 13); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Preserve wiring equalities while drifting the first straight-line transition: + // row0.pc_after += 4, row1.pc_before += 4, row1.prog_addr += 4. 
+    let row0_pc_after = layout.cell(layout.trace.pc_after, 0);
+    let row1_pc_before = layout.cell(layout.trace.pc_before, 1);
+    // NOTE(review): prog_addr shares the `pc_before` column, so the bump below
+    // already keeps active->prog_addr==pc_before; a second write double-shifts.
+    let delta = F::from_u64(4);
+    w[row0_pc_after - layout.m_in] += delta;
+    w[row1_pc_before - layout.m_in] += delta;
+
+    assert!(
+        check_ccs_rowwise_zero(&ccs, &x, &w).is_err(),
+        "tampered load PC update should fail trace CCS"
+    );
+}
+
+#[test]
+#[ignore = "moved to decode stage sidecar semantics"]
+fn rv32_trace_wiring_ccs_rejects_jal_pc_target_tamper() {
+    // Program:
+    //   JAL x1, 8
+    //   ADDI x2, x0, 1   (skipped)
+    //   BEQ x0, x0, 4
+    //   HALT
+    // Row 0 JAL target is pc_before + 8.
+    let program = vec![
+        RiscvInstruction::Jal { rd: 1, imm: 8 },
+        RiscvInstruction::IAlu {
+            op: RiscvOpcode::Add,
+            rd: 2,
+            rs1: 0,
+            imm: 1,
+        },
+        RiscvInstruction::Branch {
+            cond: BranchCondition::Eq,
+            rs1: 0,
+            rs2: 0,
+            imm: 4,
+        },
+        RiscvInstruction::Halt,
+    ];
+    let program_bytes = encode_program(&program);
+
+    let decoded_program = decode_program(&program_bytes).expect("decode_program");
+    let mut cpu = RiscvCpu::new(/*xlen=*/ 32);
+    cpu.load_program(/*base=*/ 0, decoded_program);
+    let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes);
+    let shout = RiscvShoutTables::new(/*xlen=*/ 32);
+    let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program");
+
+    let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2");
+    let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout");
+    let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness");
+    let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS");
+
+    // Keep continuity and active->prog_addr intact while forging the JAL target.
+    // row0.pc_after += 4, row1.pc_before += 4, row1.prog_addr += 4.
+    // Row1 is a BRANCH control row, so existing non-control PC constraints do not catch this.
+    let row0_pc_after = layout.cell(layout.trace.pc_after, 0);
+    let row1_pc_before = layout.cell(layout.trace.pc_before, 1);
+    // NOTE(review): prog_addr shares the `pc_before` column, so the bump below
+    // already keeps active->prog_addr==pc_before; a second write double-shifts.
+    let delta = F::from_u64(4);
+    w[row0_pc_after - layout.m_in] += delta;
+    w[row1_pc_before - layout.m_in] += delta;
+
+    assert!(
+        check_ccs_rowwise_zero(&ccs, &x, &w).is_err(),
+        "tampered JAL target should fail trace CCS"
+    );
+}
+
+#[test]
+#[ignore = "moved to decode stage sidecar semantics"]
+fn rv32_trace_wiring_ccs_rejects_jalr_pc_target_tamper() {
+    // Program:
+    //   ADDI x1, x0, 8
+    //   JALR x2, x1, 0
+    //   BEQ x0, x0, 4
+    //   HALT
+    // Row 1 JALR target is (rs1 + imm_i) masked to 4-byte alignment.
+    let program = vec![
+        RiscvInstruction::IAlu {
+            op: RiscvOpcode::Add,
+            rd: 1,
+            rs1: 0,
+            imm: 8,
+        },
+        RiscvInstruction::Jalr { rd: 2, rs1: 1, imm: 0 },
+        RiscvInstruction::Branch {
+            cond: BranchCondition::Eq,
+            rs1: 0,
+            rs2: 0,
+            imm: 4,
+        },
+        RiscvInstruction::Halt,
+    ];
+    let program_bytes = encode_program(&program);
+
+    let decoded_program = decode_program(&program_bytes).expect("decode_program");
+    let mut cpu = RiscvCpu::new(/*xlen=*/ 32);
+    cpu.load_program(/*base=*/ 0, decoded_program);
+    let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes);
+    let shout = RiscvShoutTables::new(/*xlen=*/ 32);
+    let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program");
+
+    let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2");
+    let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout");
+    let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness");
+    let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS");
+
+    // Keep continuity and 
active->prog_addr intact while forging the JALR target. + // row1.pc_after += 4, row2.pc_before += 4, row2.prog_addr += 4. + // Row2 is a BRANCH control row, so existing non-control PC constraints do not catch this. + let row1_pc_after = layout.cell(layout.trace.pc_after, 1); + let row2_pc_before = layout.cell(layout.trace.pc_before, 2); + let row2_prog_addr = layout.cell(layout.trace.pc_before, 2); + let delta = F::from_u64(4); + w[row1_pc_after - layout.m_in] += delta; + w[row2_pc_before - layout.m_in] += delta; + w[row2_prog_addr - layout.m_in] += delta; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered JALR target should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_branch_target_tamper() { + // Program: + // BEQ x0, x0, 8 + // ADDI x1, x0, 1 (skipped) + // BEQ x0, x0, 4 + // HALT + // Row 0 BEQ is always taken, so pc_after must be pc_before + 8. + let program = vec![ + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 0, + rs2: 0, + imm: 8, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 0, + rs2: 0, + imm: 4, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = 
rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Keep continuity and active->prog_addr intact while forging the BRANCH target. + // row0.pc_after += 4, row1.pc_before += 4, row1.prog_addr += 4. + // Row1 is another BRANCH control row, so existing non-control PC constraints do not catch this. + let row0_pc_after = layout.cell(layout.trace.pc_after, 0); + let row1_pc_before = layout.cell(layout.trace.pc_before, 1); + let row1_prog_addr = layout.cell(layout.trace.pc_before, 1); + let delta = F::from_u64(4); + w[row0_pc_after - layout.m_in] += delta; + w[row1_pc_before - layout.m_in] += delta; + w[row1_prog_addr - layout.m_in] += delta; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered branch target should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_load_ram_addr_tamper() { + // Program: LW x1, 4(x0); HALT + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 4, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let mut twist = + RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + twist.store(RAM_ID, /*addr=*/ 4, /*value=*/ 0x1234); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); 
+ let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Row 0 is LW. Forge RAM addr while preserving current wiring and class policy constraints. + let row0_ram_addr = layout.cell(layout.trace.ram_addr, 0); + w[row0_ram_addr - layout.m_in] += F::from_u64(4); + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered load ram_addr should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_store_ram_addr_tamper() { + // Program: ADDI x1, x0, 7; SW x1, 4(x0); HALT + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 4, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Row 1 is SW. Forge RAM addr while preserving current wiring and class policy constraints. 
+ let row1_ram_addr = layout.cell(layout.trace.ram_addr, 1); + w[row1_ram_addr - layout.m_in] += F::from_u64(4); + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered store ram_addr should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_branch_condition_shout_tamper() { + // Program: BEQ x0, x0, 8; ADDI x1, x0, 1; HALT + // BEQ compares equal, so shout_val should drive taken=1. + let program = vec![ + RiscvInstruction::Branch { + cond: BranchCondition::Eq, + rs1: 0, + rs2: 0, + imm: 8, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + // Keep PC target intact but forge branch compare output on the branch row. 
+ let row0_shout_val = layout.cell(layout.trace.shout_val, 0); + w[row0_shout_val - layout.m_in] = F::ZERO; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered branch compare output should fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_alu_value_binding_tamper() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + let rd_val_idx = layout.cell(layout.trace.rd_val, 0); + w[rd_val_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered ALU rd value must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_shout_has_lookup_tamper() { + let program = vec![ + RiscvInstruction::Branch { + cond: BranchCondition::Ltu, + rs1: 0, + rs2: 0, + imm: 4, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = 
RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + let has_lookup_idx = layout.cell(layout.trace.shout_has_lookup, 0); + w[has_lookup_idx - layout.m_in] = F::ZERO; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered branch shout has_lookup must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to width stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_load_writeback_tamper_all_widths() { + let cases = [ + (RiscvMemOp::Lb, 0x0000_00FFu64, "LB"), + (RiscvMemOp::Lbu, 0x0000_00FFu64, "LBU"), + (RiscvMemOp::Lh, 0x0000_8001u64, "LH"), + (RiscvMemOp::Lhu, 0x0000_8001u64, "LHU"), + (RiscvMemOp::Lw, 0x1234_5678u64, "LW"), + ]; + + for (op, ram_value, name) in cases { + let program = vec![ + RiscvInstruction::Load { + op, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let mut twist = + RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + twist.store(RAM_ID, /*addr=*/ 0, /*value=*/ ram_value); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 
16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + let rd_val_idx = layout.cell(layout.trace.rd_val, 0); + w[rd_val_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered {name} writeback must fail trace CCS" + ); + } +} + +#[test] +#[ignore = "moved to width stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_sw_store_value_tamper() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 9, + }, + RiscvInstruction::Store { + op: RiscvMemOp::Sw, + rs1: 0, + rs2: 1, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let mut twist = + RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + twist.store(RAM_ID, /*addr=*/ 0, /*value=*/ 0xAABB_CCDD); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + let row_store = 1usize; + let ram_wv_idx = layout.cell(layout.trace.ram_wv, row_store); + w[ram_wv_idx - layout.m_in] 
+= F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered SW store value must fail trace CCS" + ); +} + +#[test] +#[ignore = "moved to width stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_sb_sh_store_merge_tamper() { + let cases = [(RiscvMemOp::Sb, 0x12i32, "SB"), (RiscvMemOp::Sh, 0x123i32, "SH")]; + + for (op, imm, name) in cases { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm, + }, + RiscvInstruction::Store { + op, + rs1: 0, + rs2: 1, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let mut twist = + RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + twist.store(RAM_ID, /*addr=*/ 0, /*value=*/ 0xA1B2_C3D4); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, mut w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + let row_store = 1usize; + let ram_wv_idx = layout.cell(layout.trace.ram_wv, row_store); + w[ram_wv_idx - layout.m_in] += F::ONE; + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "tampered {name} merge store value must fail trace CCS" + ); + } +} + +#[test] +#[ignore = "moved to decode stage sidecar semantics"] +fn rv32_trace_wiring_ccs_rejects_rv32m_in_trace_scope() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 3, + }, + 
RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 4, + }, + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_err(), + "RV32M must be rejected in Tier 2.1 trace scope" + ); +} + +#[test] +fn rv32_trace_wiring_ccs_allows_amo_when_scope_lock_is_sidecar_owned() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 5, + }, + RiscvInstruction::Amo { + op: RiscvMemOp::AmoaddW, + rd: 2, + rs1: 0, + rs2: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + let mut twist = + RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + twist.store(RAM_ID, /*addr=*/ 0, /*value=*/ 0x44); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 
16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, /*min_len=*/ 8).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + assert!( + check_ccs_rowwise_zero(&ccs, &x, &w).is_ok(), + "N0 CCS should accept AMO rows when the Tier 2.1 scope lock is sidecar-owned (WB/decode stage)" + ); +} diff --git a/crates/neo-memory/tests/rv32_trace_all_ccs_counts.rs b/crates/neo-memory/tests/rv32_trace_all_ccs_counts.rs new file mode 100644 index 00000000..6228cbeb --- /dev/null +++ b/crates/neo-memory/tests/rv32_trace_all_ccs_counts.rs @@ -0,0 +1,126 @@ +use std::collections::HashMap; + +use neo_memory::plain::PlainMemLayout; +use neo_memory::riscv::ccs::{ + build_rv32_trace_wiring_ccs, build_rv32_trace_wiring_ccs_with_reserved_rows, + rv32_trace_shared_bus_requirements_with_specs, Rv32TraceCcsLayout, TraceShoutBusSpec, +}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, RiscvCpu, RiscvInstruction, RiscvMemory, RiscvOpcode, RiscvShoutTables, PROG_ID, + RAM_ID, REG_ID, +}; +use neo_memory::riscv::trace::{rv32_decode_lookup_table_id_for_col, Rv32DecodeSidecarLayout}; +use neo_vm_trace::trace_program; + +fn sample_mem_layouts() -> HashMap { + HashMap::from([ + ( + PROG_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 1, + }, + ), + ( + REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + ( + RAM_ID.0, + PlainMemLayout { + k: 16, + d: 4, + n_side: 2, + lanes: 1, + }, + ), + ]) +} + +fn decode_selector_specs(prog_d: usize) -> Vec { + let decode = Rv32DecodeSidecarLayout::new(); + [decode.rd_has_write, decode.ram_has_read, decode.ram_has_write] + .into_iter() + .map(|col| TraceShoutBusSpec { + table_id: 
rv32_decode_lookup_table_id_for_col(col), + ell_addr: prog_d, + n_vals: 1usize, + }) + .collect() +} + +fn trace_addi_halt_exec_table(min_len: usize) -> Rv32ExecTable { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + let decoded_program = decode_program(&program_bytes).expect("decode_program"); + + let mut cpu = RiscvCpu::new(/*xlen=*/ 32); + cpu.load_program(/*base=*/ 0, decoded_program); + + let twist = RiscvMemory::with_program_in_twist(/*xlen=*/ 32, PROG_ID, /*base_addr=*/ 0, &program_bytes); + let shout = RiscvShoutTables::new(/*xlen=*/ 32); + let trace = trace_program(cpu, twist, shout, /*max_steps=*/ 16).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, min_len).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().expect("cycle chain"); + exec.validate_pc_chain().expect("pc chain"); + exec.validate_halted_tail().expect("halted tail"); + exec.validate_inactive_rows_are_empty() + .expect("inactive rows"); + exec +} + +#[test] +fn rv32_trace_ccs_counts_follow_layout_shape() { + let exec = trace_addi_halt_exec_table(/*min_len=*/ 4); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + + assert_eq!( + layout.m, + layout.m_in + layout.trace.cols * layout.t, + "layout width must equal public + flattened trace region" + ); + assert_eq!(ccs.m, layout.m, "CCS witness width must match layout width"); + assert!(ccs.n > layout.t, "CCS should include transition + wiring constraints"); +} + +#[test] +fn rv32_trace_reserved_rows_increase_constraint_count_exactly() { + let exec = trace_addi_halt_exec_table(/*min_len=*/ 4); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let ccs_base = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS base"); + + 
let mem_layouts = sample_mem_layouts(); + let decode_specs = decode_selector_specs(mem_layouts[&PROG_ID.0].d); + let add_table_id = RiscvShoutTables::new(32).opcode_to_id(RiscvOpcode::Add).0; + let (_bus_region_len, reserved_rows) = + rv32_trace_shared_bus_requirements_with_specs(&layout, &[add_table_id], &decode_specs, &mem_layouts) + .expect("trace shared bus requirements"); + + let ccs_reserved = + build_rv32_trace_wiring_ccs_with_reserved_rows(&layout, reserved_rows).expect("trace CCS reserved"); + + assert!(reserved_rows > 0, "expected reserved rows from shared bus requirements"); + assert_eq!( + ccs_reserved.n, + ccs_base.n + reserved_rows, + "reserved rows must add directly to CCS row count" + ); +} diff --git a/crates/neo-memory/tests/shout_byte_decomp_semantics.rs b/crates/neo-memory/tests/shout_byte_decomp_semantics.rs index 1c19179d..de807d98 100644 --- a/crates/neo-memory/tests/shout_byte_decomp_semantics.rs +++ b/crates/neo-memory/tests/shout_byte_decomp_semantics.rs @@ -48,6 +48,7 @@ fn build_single_lane_explicit_lut_witness( } let inst = LutInstance::<(), F> { + table_id: 0, comms: Vec::new(), k: n_side, d: 1, @@ -57,6 +58,8 @@ fn build_single_lane_explicit_lut_witness( ell, table_spec: None, table: table.clone(), + addr_group: None, + selector_group: None, }; // Layout: [addr_bits(ell), has_lookup, val]. 
diff --git a/crates/neo-memory/tests/sparse_matrix_mle_correctness.rs b/crates/neo-memory/tests/sparse_matrix_mle_correctness.rs new file mode 100644 index 00000000..04087f69 --- /dev/null +++ b/crates/neo-memory/tests/sparse_matrix_mle_correctness.rs @@ -0,0 +1,70 @@ +use neo_math::K; +use neo_memory::mle::chi_at_index; +use neo_memory::sparse_matrix::{SparseMat, SparseMatEntry}; +use p3_field::PrimeCharacteristicRing; +use rand::Rng; +use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha8Rng; + +#[test] +fn sparse_matrix_mle_eval_by_folding_matches_dense() { + let mut rng = ChaCha8Rng::seed_from_u64(12345); + + let ell_row = 3usize; + let ell_col = 4usize; + let row_len = 1usize << ell_row; + let col_len = 1usize << ell_col; + + for _trial in 0..50usize { + // Random sparse entries with possible duplicates. + let nnz = rng.random_range(0..=40usize); + let mut entries: Vec> = Vec::with_capacity(nnz); + for _ in 0..nnz { + let row = rng.random_range(0..row_len) as u64; + let col = rng.random_range(0..col_len) as u64; + let value = K::from_u64(rng.random::()); + entries.push(SparseMatEntry { row, col, value }); + } + + // Dense materialization (with duplicate summation). + let mut dense = vec![K::ZERO; row_len * col_len]; + for e in entries.iter() { + dense[e.row as usize * col_len + e.col as usize] += e.value; + } + + let sparse = SparseMat::from_entries(ell_row, ell_col, entries); + + // Random evaluation point. + let r_row: Vec = (0..ell_row) + .map(|_| K::from_u64(rng.random::())) + .collect(); + let r_col: Vec = (0..ell_col) + .map(|_| K::from_u64(rng.random::())) + .collect(); + + // Dense MLE eval: Σ_{i,j} M[i,j] χ_r_row(i) χ_r_col(j). 
+ let mut expected = K::ZERO; + for row in 0..row_len { + let chi_r = chi_at_index(&r_row, row); + if chi_r == K::ZERO { + continue; + } + for col in 0..col_len { + let v = dense[row * col_len + col]; + if v == K::ZERO { + continue; + } + expected += v * chi_r * chi_at_index(&r_col, col); + } + } + + let got = sparse + .mle_eval_by_folding(&r_row, &r_col) + .expect("mle_eval_by_folding"); + let got_direct = sparse + .mle_eval_direct(&r_row, &r_col) + .expect("mle_eval_direct"); + assert_eq!(got, expected); + assert_eq!(got_direct, expected); + } +} diff --git a/crates/neo-reductions/src/engines/optimized_engine/common.rs b/crates/neo-reductions/src/engines/optimized_engine/common.rs index 19c81021..ff925bc7 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/common.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/common.rs @@ -818,12 +818,12 @@ where /// - This helper performs only algebraic mixing over witnesses and outputs; it does not compute the /// combined commitment. The output `c` is copied from the first input as a placeholder. /// - Caller should set `out.c = Σ ρ_i · c_i` using the commitment module action if a commitment mix is required. 
-pub fn rlc_reduction_paper_exact( +fn rlc_reduction_paper_exact_from_refs( s: &CcsStructure, params: &NeoParams, rhos: &[Mat], me_inputs: &[MeInstance], - Zs: &[Mat], + Zs: &[&Mat], ell_d: usize, ) -> (MeInstance, Mat) where @@ -940,7 +940,7 @@ where // Z := Σ ρ_i Z_i let mut Z = Mat::zero(d, s.m, Ff::ZERO); for i in 0..k1 { - left_mul_acc(&mut Z, &rhos[i], &Zs[i]); + left_mul_acc(&mut Z, &rhos[i], Zs[i]); } let out = MeInstance:: { @@ -961,6 +961,22 @@ where (out, Z) } +pub fn rlc_reduction_paper_exact( + s: &CcsStructure, + params: &NeoParams, + rhos: &[Mat], + me_inputs: &[MeInstance], + Zs: &[Mat], + ell_d: usize, +) -> (MeInstance, Mat) +where + Ff: Field + PrimeCharacteristicRing + Copy + Send + Sync, + K: From, +{ + let z_refs: Vec<&Mat> = Zs.iter().collect(); + rlc_reduction_paper_exact_from_refs::(s, params, rhos, me_inputs, &z_refs, ell_d) +} + /// --- Π_RLC (optimized) ----------------------------------------------------- /// /// Optimized Random Linear Combination for the prover path. @@ -968,12 +984,12 @@ where /// Semantics match `rlc_reduction_paper_exact`, but this implementation: /// - Fast-paths the common `k=1` case (no mixing) to avoid a D×D by D×m multiply. /// - Uses cache-friendly row-major loops for the large witness matrix `Z` when k>1. 
-pub fn rlc_reduction_optimized( +fn rlc_reduction_optimized_from_refs( s: &CcsStructure, params: &NeoParams, rhos: &[Mat], me_inputs: &[MeInstance], - Zs: &[Mat], + Zs: &[&Mat], ell_d: usize, ) -> (MeInstance, Mat) where @@ -996,7 +1012,7 @@ where let mut out = me_inputs[0].clone(); out.y.truncate(t_core); out.y_scalars.truncate(t_core); - return (out, Zs[0].clone()); + return (out, (*Zs[0]).clone()); } let d = D; @@ -1112,30 +1128,63 @@ where let m = s.m; let z_out = Z.as_mut_slice(); const BLOCK_COLS: usize = 1024; - for i in 0..k1 { - let rho = &rhos[i]; - let zi = &Zs[i]; - debug_assert_eq!(rho.rows(), d); - debug_assert_eq!(rho.cols(), d); - debug_assert_eq!(zi.rows(), d); - debug_assert_eq!(zi.cols(), m); - - let rho_data = rho.as_slice(); - let z_in = zi.as_slice(); - - for col0 in (0..m).step_by(BLOCK_COLS) { - let len = core::cmp::min(BLOCK_COLS, m - col0); - for kk in 0..d { - let in_off = kk * m + col0; - for rr in 0..d { - let coeff = rho_data[rr * d + kk]; - if coeff == Ff::ZERO { - continue; + debug_assert_eq!(rhos[i].rows(), d); + debug_assert_eq!(rhos[i].cols(), d); + debug_assert_eq!(Zs[i].rows(), d); + debug_assert_eq!(Zs[i].cols(), m); + } + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + z_out + .par_chunks_exact_mut(m) + .enumerate() + .for_each(|(rr, row_out)| { + for col0 in (0..m).step_by(BLOCK_COLS) { + let len = core::cmp::min(BLOCK_COLS, m - col0); + for i in 0..k1 { + let rho_data = rhos[i].as_slice(); + let z_in = Zs[i].as_slice(); + for kk in 0..d { + let coeff = rho_data[rr * d + kk]; + if coeff == Ff::ZERO { + continue; + } + let in_off = kk * m + col0; + for t in 0..len { + row_out[col0 + t] += coeff * z_in[in_off + t]; + } + } } - let out_off = rr * m + col0; - for t in 0..len { - z_out[out_off + t] += coeff * z_in[in_off + t]; + } + }); + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + for i in 0..k1 { + let rho = &rhos[i]; + let zi = Zs[i]; + debug_assert_eq!(rho.rows(), d); 
+ debug_assert_eq!(rho.cols(), d); + debug_assert_eq!(zi.rows(), d); + debug_assert_eq!(zi.cols(), m); + + let rho_data = rho.as_slice(); + let z_in = zi.as_slice(); + + for col0 in (0..m).step_by(BLOCK_COLS) { + let len = core::cmp::min(BLOCK_COLS, m - col0); + for kk in 0..d { + let in_off = kk * m + col0; + for rr in 0..d { + let coeff = rho_data[rr * d + kk]; + if coeff == Ff::ZERO { + continue; + } + let out_off = rr * m + col0; + for t in 0..len { + z_out[out_off + t] += coeff * z_in[in_off + t]; + } } } } @@ -1161,6 +1210,22 @@ where (out, Z) } +pub fn rlc_reduction_optimized( + s: &CcsStructure, + params: &NeoParams, + rhos: &[Mat], + me_inputs: &[MeInstance], + Zs: &[Mat], + ell_d: usize, +) -> (MeInstance, Mat) +where + Ff: Field + PrimeCharacteristicRing + Copy + Send + Sync, + K: From, +{ + let z_refs: Vec<&Mat> = Zs.iter().collect(); + rlc_reduction_optimized_from_refs::(s, params, rhos, me_inputs, &z_refs, ell_d) +} + /// Same as `rlc_reduction_paper_exact`, but also computes the combined commitment via a caller-supplied /// mixing function over commitments. This matches the paper's Π_RLC output when `combine_commit` implements /// the correct S-module action on commitments. @@ -1186,6 +1251,28 @@ where (out, Z) } +/// Same as `rlc_reduction_optimized`, but computes the combined commitment via a caller-supplied +/// mixing function and accepts borrowed witness matrices (to avoid cloning large Z inputs). 
+pub fn rlc_reduction_optimized_with_commit_mix( + s: &CcsStructure, + params: &NeoParams, + rhos: &[Mat], + me_inputs: &[MeInstance], + Zs: &[&Mat], + ell_d: usize, + combine_commit: Comb, +) -> (MeInstance, Mat) +where + Ff: Field + PrimeCharacteristicRing + Copy + Send + Sync, + K: From, + Comb: Fn(&[Mat], &[Cmt]) -> Cmt, +{ + let (mut out, Z) = rlc_reduction_optimized_from_refs::(s, params, rhos, me_inputs, Zs, ell_d); + let inputs_c: Vec = me_inputs.iter().map(|m| m.c.clone()).collect(); + out.c = combine_commit(rhos, &inputs_c); + (out, Z) +} + /// --- Π_DEC (Section 4.6) --------------------------------------------------- /// /// Paper-exact decomposition: given parent ME(B,L) and a provided split Z = Σ b^i · Z_i, @@ -1363,8 +1450,12 @@ where for rho in 0..d { let mut acc = K::ZERO; - for c in 0..s.m { - acc += K::from(get_F(Zi, rho, c)) * vj[c]; + if rho < Zi.rows() { + let z_row = Zi.row(rho); + let cap = core::cmp::min(s.m, z_row.len()); + for c in 0..cap { + acc += K::from(z_row[c]) * vj[c]; + } } yij_pad[rho] = acc; } @@ -1385,8 +1476,12 @@ where let mut yz_pad = vec![K::ZERO; d_pad]; for rho in 0..d { let mut acc = K::ZERO; - for c in 0..s.m { - acc += K::from(get_F(Zi, rho, c)) * chi_s[c]; + if rho < Zi.rows() { + let z_row = Zi.row(rho); + let cap = core::cmp::min(s.m, z_row.len()); + for c in 0..cap { + acc += K::from(z_row[c]) * chi_s[c]; + } } yz_pad[rho] = acc; } diff --git a/crates/neo-reductions/src/engines/optimized_engine/mod.rs b/crates/neo-reductions/src/engines/optimized_engine/mod.rs index 4cf317b3..f4276945 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/mod.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/mod.rs @@ -48,6 +48,7 @@ pub use common::{ recomposed_z_from_Z, rlc_reduction_optimized, + rlc_reduction_optimized_with_commit_mix, // Paper-exact RLC/DEC rlc_reduction_paper_exact, rlc_reduction_paper_exact_with_commit_mix, diff --git a/crates/neo-reductions/src/engines/optimized_engine/oracle.rs 
b/crates/neo-reductions/src/engines/optimized_engine/oracle.rs index c569e7fc..dfc37631 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/oracle.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/oracle.rs @@ -135,9 +135,10 @@ where } let mut tbl = vec![[K::ZERO; D]; m_pad]; let cap = core::cmp::min(s.m, m_pad); - for col in 0..cap { - for rho in 0..D { - tbl[col][rho] = K::from(Zi[(rho, col)]); + for rho in 0..D { + let z_row = Zi.row(rho); + for col in 0..cap { + tbl[col][rho] = K::from(z_row[col]); } } digits_tables.push(tbl); @@ -205,7 +206,7 @@ where table.truncate(half); } - fn evals_col_phase(&self, xs: &[K]) -> Vec { + fn evals_col_phase_generic(&self, xs: &[K]) -> Vec { debug_assert!(self.cur_len >= 2 && self.cur_len % 2 == 0); let tail_len = self.cur_len / 2; let xs_len = xs.len(); @@ -316,6 +317,319 @@ where } } + fn evals_col_phase_b2(&self, xs: &[K]) -> Vec { + debug_assert!(self.cur_len >= 2 && self.cur_len % 2 == 0); + let tail_len = self.cur_len / 2; + if xs.is_empty() { + return Vec::new(); + } + + const PAR_THRESHOLD: usize = 1 << 13; + let three = K::from(F::from_u64(3)); + + let coeffs_seq = |tail_len: usize| -> [K; 5] { + let mut coeffs = [K::ZERO; 5]; + for t in 0..tail_len { + let idx = 2 * t; + let e0 = self.eq_beta_m_tbl[idx]; + let e1 = self.eq_beta_m_tbl[idx + 1] - e0; + + let mut inner = [K::ZERO; 4]; + for (wit_idx, tbl) in self.digits_tables.iter().enumerate() { + let lo = &tbl[idx]; + let hi = &tbl[idx + 1]; + let weights = &self.weights[wit_idx]; + + for rho in 0..D { + let w = weights[rho]; + let a = lo[rho]; + let b = hi[rho] - a; + + let a2 = a * a; + let a3 = a2 * a; + let b2 = b * b; + let b3 = b2 * b; + + // N(a+bX) = (a+bX)^3 - (a+bX) + let t0 = a3 - a; + let t1 = (a2 * b).scale_base_k(three) - b; + let t2 = (a * b2).scale_base_k(three); + let t3 = b3; + + inner[0] += w * t0; + inner[1] += w * t1; + inner[2] += w * t2; + inner[3] += w * t3; + } + } + + // (e0 + e1 X) * (inner0 + inner1 X + inner2 X^2 + 
inner3 X^3) + coeffs[0] += e0 * inner[0]; + coeffs[1] += e0 * inner[1] + e1 * inner[0]; + coeffs[2] += e0 * inner[2] + e1 * inner[1]; + coeffs[3] += e0 * inner[3] + e1 * inner[2]; + coeffs[4] += e1 * inner[3]; + } + coeffs + }; + + let coeffs = if tail_len >= PAR_THRESHOLD { + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + (0..tail_len) + .into_par_iter() + .fold( + || [K::ZERO; 5], + |mut coeffs, t| { + let idx = 2 * t; + let e0 = self.eq_beta_m_tbl[idx]; + let e1 = self.eq_beta_m_tbl[idx + 1] - e0; + + let mut inner = [K::ZERO; 4]; + for (wit_idx, tbl) in self.digits_tables.iter().enumerate() { + let lo = &tbl[idx]; + let hi = &tbl[idx + 1]; + let weights = &self.weights[wit_idx]; + + for rho in 0..D { + let w = weights[rho]; + let a = lo[rho]; + let b = hi[rho] - a; + + let a2 = a * a; + let a3 = a2 * a; + let b2 = b * b; + let b3 = b2 * b; + + let t0 = a3 - a; + let t1 = (a2 * b).scale_base_k(three) - b; + let t2 = (a * b2).scale_base_k(three); + let t3 = b3; + + inner[0] += w * t0; + inner[1] += w * t1; + inner[2] += w * t2; + inner[3] += w * t3; + } + } + + coeffs[0] += e0 * inner[0]; + coeffs[1] += e0 * inner[1] + e1 * inner[0]; + coeffs[2] += e0 * inner[2] + e1 * inner[1]; + coeffs[3] += e0 * inner[3] + e1 * inner[2]; + coeffs[4] += e1 * inner[3]; + coeffs + }, + ) + .reduce( + || [K::ZERO; 5], + |mut a, b| { + for i in 0..5 { + a[i] += b[i]; + } + a + }, + ) + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + coeffs_seq(tail_len) + } + } else { + coeffs_seq(tail_len) + }; + + let xs_are_base = xs.iter().all(|&x| x.imag() == Fq::ZERO); + if xs_are_base { + xs.iter() + .map(|&x| crate::sumcheck::poly_eval_k_base(&coeffs, x.real())) + .collect() + } else { + xs.iter() + .map(|&x| crate::sumcheck::poly_eval_k(&coeffs, x)) + .collect() + } + } + + fn evals_col_phase_b3(&self, xs: &[K]) -> Vec { + debug_assert!(self.cur_len >= 2 && self.cur_len % 2 == 0); + let tail_len = self.cur_len / 2; + if xs.is_empty() { 
+ return Vec::new(); + } + + const PAR_THRESHOLD: usize = 1 << 13; + let four = K::from(F::from_u64(4)); + let five = K::from(F::from_u64(5)); + let ten = K::from(F::from_u64(10)); + let fifteen = K::from(F::from_u64(15)); + + let coeffs_seq = |tail_len: usize| -> [K; 7] { + let mut coeffs = [K::ZERO; 7]; + for t in 0..tail_len { + let idx = 2 * t; + let e0 = self.eq_beta_m_tbl[idx]; + let e1 = self.eq_beta_m_tbl[idx + 1] - e0; + + let mut inner = [K::ZERO; 6]; + for (wit_idx, tbl) in self.digits_tables.iter().enumerate() { + let lo = &tbl[idx]; + let hi = &tbl[idx + 1]; + let weights = &self.weights[wit_idx]; + + for rho in 0..D { + let w = weights[rho]; + let a = lo[rho]; + let b = hi[rho] - a; + + let a2 = a * a; + let a3 = a2 * a; + let a4 = a2 * a2; + let a5 = a4 * a; + + let b2 = b * b; + let b3 = b2 * b; + let b4 = b2 * b2; + let b5 = b4 * b; + + // N(a+bX) = (a+bX)^5 - 5(a+bX)^3 + 4(a+bX) + let t0 = a5 - a3.scale_base_k(five) + a.scale_base_k(four); + let t1 = b * (a4.scale_base_k(five) - a2.scale_base_k(fifteen) + four); + let t2 = b2 * (a3.scale_base_k(ten) - a.scale_base_k(fifteen)); + let t3 = b3 * (a2.scale_base_k(ten) - five); + let t4 = b4 * a.scale_base_k(five); + let t5 = b5; + + inner[0] += w * t0; + inner[1] += w * t1; + inner[2] += w * t2; + inner[3] += w * t3; + inner[4] += w * t4; + inner[5] += w * t5; + } + } + + // (e0 + e1 X) * Σ_{k=0..5} inner[k] X^k + coeffs[0] += e0 * inner[0]; + coeffs[1] += e0 * inner[1] + e1 * inner[0]; + coeffs[2] += e0 * inner[2] + e1 * inner[1]; + coeffs[3] += e0 * inner[3] + e1 * inner[2]; + coeffs[4] += e0 * inner[4] + e1 * inner[3]; + coeffs[5] += e0 * inner[5] + e1 * inner[4]; + coeffs[6] += e1 * inner[5]; + } + coeffs + }; + + let coeffs = if tail_len >= PAR_THRESHOLD { + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + (0..tail_len) + .into_par_iter() + .fold( + || [K::ZERO; 7], + |mut coeffs, t| { + let idx = 2 * t; + let e0 = self.eq_beta_m_tbl[idx]; + let e1 = self.eq_beta_m_tbl[idx 
+ 1] - e0; + + let mut inner = [K::ZERO; 6]; + for (wit_idx, tbl) in self.digits_tables.iter().enumerate() { + let lo = &tbl[idx]; + let hi = &tbl[idx + 1]; + let weights = &self.weights[wit_idx]; + + for rho in 0..D { + let w = weights[rho]; + let a = lo[rho]; + let b = hi[rho] - a; + + let a2 = a * a; + let a3 = a2 * a; + let a4 = a2 * a2; + let a5 = a4 * a; + + let b2 = b * b; + let b3 = b2 * b; + let b4 = b2 * b2; + let b5 = b4 * b; + + let t0 = a5 - a3.scale_base_k(five) + a.scale_base_k(four); + let t1 = b * (a4.scale_base_k(five) - a2.scale_base_k(fifteen) + four); + let t2 = b2 * (a3.scale_base_k(ten) - a.scale_base_k(fifteen)); + let t3 = b3 * (a2.scale_base_k(ten) - five); + let t4 = b4 * a.scale_base_k(five); + let t5 = b5; + + inner[0] += w * t0; + inner[1] += w * t1; + inner[2] += w * t2; + inner[3] += w * t3; + inner[4] += w * t4; + inner[5] += w * t5; + } + } + + coeffs[0] += e0 * inner[0]; + coeffs[1] += e0 * inner[1] + e1 * inner[0]; + coeffs[2] += e0 * inner[2] + e1 * inner[1]; + coeffs[3] += e0 * inner[3] + e1 * inner[2]; + coeffs[4] += e0 * inner[4] + e1 * inner[3]; + coeffs[5] += e0 * inner[5] + e1 * inner[4]; + coeffs[6] += e1 * inner[5]; + coeffs + }, + ) + .reduce( + || [K::ZERO; 7], + |mut a, b| { + for i in 0..7 { + a[i] += b[i]; + } + a + }, + ) + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + coeffs_seq(tail_len) + } + } else { + coeffs_seq(tail_len) + }; + + let xs_are_base = xs.iter().all(|&x| x.imag() == Fq::ZERO); + if xs_are_base { + xs.iter() + .map(|&x| crate::sumcheck::poly_eval_k_base(&coeffs, x.real())) + .collect() + } else { + xs.iter() + .map(|&x| crate::sumcheck::poly_eval_k(&coeffs, x)) + .collect() + } + } + + fn evals_col_phase(&self, xs: &[K]) -> Vec { + match self.params.b { + 2 => self.evals_col_phase_b2(xs), + 3 => self.evals_col_phase_b3(xs), + _ => self.evals_col_phase_generic(xs), + } + } + + #[doc(hidden)] + pub fn __test_col_phase_fast_vs_generic(&self, xs: &[K]) -> Option<(Vec, 
Vec)> { + if self.round_idx >= self.ell_m { + return None; + } + match self.params.b { + 2 => Some((self.evals_col_phase_b2(xs), self.evals_col_phase_generic(xs))), + 3 => Some((self.evals_col_phase_b3(xs), self.evals_col_phase_generic(xs))), + _ => None, + } + } + fn evals_ajtai_phase(&self, xs: &[K]) -> Vec { let j = self.round_idx - self.ell_m; debug_assert!(j < self.ell_d, "NC Ajtai phase after all Ajtai bits"); @@ -438,25 +752,12 @@ struct RowStreamState { /// Compiled sparse polynomial terms for `f` using `f_var_tables` indices. f_terms: Vec, - /// NC tables: for each witness i, a row-domain table where each entry holds Ajtai digits - /// `[Z_i[0,row], ..., Z_i[D-1,row]]`. - nc_tables: Vec>, - - /// Precomputed products `w_beta_a[rho] * gamma^(i+1)` flattened as `i*D + rho`. - /// - /// This lets NC accumulation multiply each polynomial coefficient once (by `w*gamma`) and - /// avoids a second pass that would otherwise scale by `w_beta_a` per coefficient. - w_gamma_nc: Vec, - /// Combined Eval block table over rows (already summed over α' and (i,j) coefficients). /// When present, Eval contribution is: `eq_r_inputs(r') * gamma_to_k * eval_tbl(r')`. eval_tbl: Option>, gamma_to_k: K, b: u32, - /// Precomputed squares t^2 for the symmetric range polynomial factors (t=1..b-1). - range_t_sq: Vec, - /// True if all streamed tables are still in the base-field embedding (imag=0). /// /// When this holds and evaluation points are also base-field, we can evaluate the hot @@ -473,7 +774,6 @@ impl RowStreamState { ell_n: usize, mcs_witnesses: &[McsWitness], me_witnesses: &[Mat], - include_nc: bool, r_inputs: Option<&[K]>, sparse: &SparseCache, ) -> Self @@ -568,7 +868,7 @@ impl RowStreamState { .collect(); let k_total = all_witnesses.len(); - // NC: w_beta_a weights and gamma coefficients. + // Sanity: challenge vectors for Ajtai rounds must match ell_d. 
if ch.beta_a.len() != ell_d || ch.alpha.len() != ell_d { panic!( "Challenge length mismatch: alpha.len()={}, beta_a.len()={}, ell_d={ell_d}", @@ -576,24 +876,6 @@ impl RowStreamState { ch.beta_a.len() ); } - let w_gamma_nc = if include_nc { - let w_beta_a: Vec = (0..D) - .map(|rho| eq_points_bool_mask(rho, &ch.beta_a)) - .collect(); - let mut w_gamma_nc = vec![K::ZERO; k_total * D]; - let mut g = ch.gamma; - for i in 0..k_total { - let base = i * D; - for rho in 0..D { - w_gamma_nc[base + rho] = w_beta_a[rho] * g; - } - g *= ch.gamma; - } - w_gamma_nc - } else { - Vec::new() - }; - // Build z1 = recomposition of Z_1 (first MCS witness). let mut z1: Vec = vec![K::ZERO; s.m]; { @@ -671,35 +953,6 @@ impl RowStreamState { f_var_tables.push(out); } - // NC tables: digits at each boolean row index (legacy, only needed when NC is enabled). - let nc_tables = if include_nc { - let mut nc_tables: Vec> = Vec::with_capacity(k_total); - for &Zi in &all_witnesses { - if Zi.rows() != D || Zi.cols() != s.m { - panic!( - "Z shape mismatch: expected {}×{}, got {}×{}", - D, - s.m, - Zi.rows(), - Zi.cols() - ); - } - let mut tbl = vec![[K::ZERO; D]; n_pad]; - // Legacy NC table indexes witness columns by row-domain boolean index `row`. - // This is only sound under the identity-first square normal form (m == n). - let cap = core::cmp::min(core::cmp::min(n_eff, n_pad), s.m); - for row in 0..cap { - for rho in 0..D { - tbl[row][rho] = K::from(Zi[(rho, row)]); - } - } - nc_tables.push(tbl); - } - nc_tables - } else { - Vec::new() - }; - // Eval table (optional): only when both (a) there are carried witnesses, and (b) r_inputs exist. 
let mut gamma_to_k = K::ONE; for _ in 0..k_total { @@ -791,15 +1044,6 @@ impl RowStreamState { None }; - let mut range_t_sq = Vec::new(); - if include_nc && b > 1 { - range_t_sq.reserve((b - 1) as usize); - for t in 1..(b as i64) { - let tt = Ff::from_i64(t); - range_t_sq.push(K::from(tt * tt)); - } - } - Self { cur_len: n_pad, eq_beta_r_tbl, @@ -807,12 +1051,9 @@ impl RowStreamState { z1, f_var_tables, f_terms, - nc_tables, - w_gamma_nc, eval_tbl, gamma_to_k, b, - range_t_sq, all_base, } } @@ -829,21 +1070,6 @@ impl RowStreamState { table.truncate(half); } - #[inline] - fn fold_digits_table_inplace(table: &mut Vec<[K; D]>, r: K) { - debug_assert!(table.len() >= 2 && table.len() % 2 == 0); - let half = table.len() / 2; - for i in 0..half { - let base = 2 * i; - for rho in 0..D { - let lo = table[base][rho]; - let hi = table[base + 1][rho]; - table[i][rho] = lo + (hi - lo) * r; - } - } - table.truncate(half); - } - #[inline] fn fold_table_inplace_base(table: &mut Vec, r: Fq) { debug_assert!(table.len() >= 2 && table.len() % 2 == 0); @@ -856,21 +1082,6 @@ impl RowStreamState { table.truncate(half); } - #[inline] - fn fold_digits_table_inplace_base(table: &mut Vec<[K; D]>, r: Fq) { - debug_assert!(table.len() >= 2 && table.len() % 2 == 0); - let half = table.len() / 2; - for i in 0..half { - let base = 2 * i; - for rho in 0..D { - let lo = table[base][rho].real(); - let hi = table[base + 1][rho].real(); - table[i][rho] = K::from(lo + (hi - lo) * r); - } - } - table.truncate(half); - } - fn fold_inplace(&mut self, r: K) { if self.all_base && r.imag() == Fq::ZERO { let r0 = r.real(); @@ -881,9 +1092,6 @@ impl RowStreamState { for tbl in self.f_var_tables.iter_mut() { Self::fold_table_inplace_base(tbl, r0); } - for tbl in self.nc_tables.iter_mut() { - Self::fold_digits_table_inplace_base(tbl, r0); - } if let Some(tbl) = self.eval_tbl.as_mut() { Self::fold_table_inplace_base(tbl, r0); } @@ -896,9 +1104,6 @@ impl RowStreamState { for tbl in self.f_var_tables.iter_mut() { 
Self::fold_table_inplace(tbl, r); } - for tbl in self.nc_tables.iter_mut() { - Self::fold_digits_table_inplace(tbl, r); - } if let Some(tbl) = self.eval_tbl.as_mut() { Self::fold_table_inplace(tbl, r); } @@ -907,11 +1112,11 @@ impl RowStreamState { } #[inline] - fn poly_mul_affine_inplace_base(poly: &mut [Fq], a: Fq, b: Fq) { + fn poly_mul_affine_inplace_base(poly: &mut [Fq], a: Fq, b: Fq, current_deg: usize) { // Coeffs are low→high. Output truncates to input length: // new[0] = a*old[0]; new[d] = a*old[d] + b*old[d-1] (d>=1). let mut prev = Fq::ZERO; - for coeff in poly.iter_mut() { + for coeff in poly.iter_mut().take(current_deg + 2) { let old = *coeff; *coeff = a * old + b * prev; prev = old; @@ -932,7 +1137,6 @@ impl RowStreamState { fn evals_row_phase_b2_base(&self, tail_len: usize, xs: &[K]) -> Vec { let xs_base: Vec = xs.iter().map(|&x| x.real()).collect(); - let three = Fq::from_u64(3); let f_max_term_deg: usize = self .f_terms @@ -945,8 +1149,8 @@ impl RowStreamState { }) .max() .unwrap_or(0); - // NC contributes degree 4 after multiplying by eq_beta_r(X). - let deg_max = core::cmp::max(4, f_max_term_deg + 1); + // eq_beta_r(X) adds one degree; Eval block is quadratic. 
+ let deg_max = core::cmp::max(2, f_max_term_deg + 1); const PAR_THRESHOLD: usize = 1 << 14; let coeffs_seq = |tail_len: usize| -> Vec { @@ -964,59 +1168,21 @@ impl RowStreamState { for term in &self.f_terms { term_poly.fill(Fq::ZERO); term_poly[0] = term.coeff.real(); + let mut current_deg = 0usize; for &(var_pos, exp) in &term.vars { - if exp == 0 { - continue; - } let tbl = &self.f_var_tables[var_pos]; let a = tbl[idx].real(); let b = tbl[idx + 1].real() - a; for _ in 0..exp { - Self::poly_mul_affine_inplace_base(&mut term_poly, a, b); + Self::poly_mul_affine_inplace_base(&mut term_poly, a, b, current_deg); + current_deg += 1; } } - for i in 0..=deg_max { + for i in 0..=core::cmp::min(current_deg, deg_max) { inner[i] += term_poly[i]; } } - // NC: degree-3 in X before multiplying by eq_beta_r(X). - // - // Accumulate into locals to avoid repeated loads/stores to `inner[*]` in the hot rho loop. - let mut nc0 = Fq::ZERO; - let mut nc1 = Fq::ZERO; - let mut nc2 = Fq::ZERO; - let mut nc3 = Fq::ZERO; - for w_i in 0..self.nc_tables.len() { - let lo = &self.nc_tables[w_i][idx]; - let hi = &self.nc_tables[w_i][idx + 1]; - let base = w_i * D; - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho].real(); - let a = lo[rho].real(); - let b = hi[rho].real() - a; - - let a2 = a * a; - let a3 = a2 * a; - let b2 = b * b; - let b3 = b2 * b; - - let t0 = a3 - a; - let t1 = (a2 * b) * three - b; - let t2 = (a * b2) * three; - let t3 = b3; - - nc0 += wg * t0; - nc1 += wg * t1; - nc2 += wg * t2; - nc3 += wg * t3; - } - } - inner[0] += nc0; - inner[1] += nc1; - inner[2] += nc2; - inner[3] += nc3; - coeffs[0] += e0 * inner[0]; for d in 1..=deg_max { coeffs[d] += (e0 * inner[d]) + (e1 * inner[d - 1]); @@ -1064,64 +1230,22 @@ impl RowStreamState { for term in &self.f_terms { term_poly.fill(Fq::ZERO); term_poly[0] = term.coeff.real(); + let mut current_deg = 0usize; for &(var_pos, exp) in &term.vars { - if exp == 0 { - continue; - } let tbl = &self.f_var_tables[var_pos]; // v(X) = a 
+ b·X let a = tbl[idx].real(); let b = tbl[idx + 1].real() - a; for _ in 0..exp { - Self::poly_mul_affine_inplace_base(&mut term_poly, a, b); + Self::poly_mul_affine_inplace_base(&mut term_poly, a, b, current_deg); + current_deg += 1; } } - for i in 0..=deg_max { + for i in 0..=core::cmp::min(current_deg, deg_max) { inner[i] += term_poly[i]; } } - // NC: degree-3 in X before multiplying by eq_beta_r(X). - // - // Accumulate into locals to avoid repeated loads/stores to `inner[*]` in the hot rho loop. - let mut nc0 = Fq::ZERO; - let mut nc1 = Fq::ZERO; - let mut nc2 = Fq::ZERO; - let mut nc3 = Fq::ZERO; - for w_i in 0..self.nc_tables.len() { - let lo = &self.nc_tables[w_i][idx]; - let hi = &self.nc_tables[w_i][idx + 1]; - let base = w_i * D; - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho].real(); - - // y(X) = a + b·X - let a = lo[rho].real(); - let b = hi[rho].real() - a; - - // For b=2, N(y) = y(y^2-1) = y^3 - y. - // (a + bX)^3 - (a + bX) - let a2 = a * a; - let a3 = a2 * a; - let b2 = b * b; - let b3 = b2 * b; - - let t0 = a3 - a; - let t1 = (a2 * b) * three - b; - let t2 = (a * b2) * three; - let t3 = b3; - - nc0 += wg * t0; - nc1 += wg * t1; - nc2 += wg * t2; - nc3 += wg * t3; - } - } - inner[0] += nc0; - inner[1] += nc1; - inner[2] += nc2; - inner[3] += nc3; - // coeffs += eq_beta_r(X) * inner(X) coeffs[0] += e0 * inner[0]; for d in 1..=deg_max { @@ -1173,10 +1297,6 @@ impl RowStreamState { fn evals_row_phase_b3_base(&self, tail_len: usize, xs: &[K]) -> Vec { let xs_base: Vec = xs.iter().map(|&x| x.real()).collect(); - let four = Fq::from_u64(4); - let five = Fq::from_u64(5); - let ten = Fq::from_u64(10); - let fifteen = Fq::from_u64(15); let f_max_term_deg: usize = self .f_terms @@ -1189,8 +1309,8 @@ impl RowStreamState { }) .max() .unwrap_or(0); - // NC contributes degree 6 after multiplying by eq_beta_r(X). - let deg_max = core::cmp::max(6, f_max_term_deg + 1); + // eq_beta_r(X) adds one degree; Eval block is quadratic. 
+ let deg_max = core::cmp::max(2, f_max_term_deg + 1); const PAR_THRESHOLD: usize = 1 << 14; let coeffs_seq = |tail_len: usize| -> Vec { @@ -1208,79 +1328,21 @@ impl RowStreamState { for term in &self.f_terms { term_poly.fill(Fq::ZERO); term_poly[0] = term.coeff.real(); + let mut current_deg = 0usize; for &(var_pos, exp) in &term.vars { - if exp == 0 { - continue; - } let tbl = &self.f_var_tables[var_pos]; let a = tbl[idx].real(); let b = tbl[idx + 1].real() - a; for _ in 0..exp { - Self::poly_mul_affine_inplace_base(&mut term_poly, a, b); + Self::poly_mul_affine_inplace_base(&mut term_poly, a, b, current_deg); + current_deg += 1; } } - for i in 0..=deg_max { + for i in 0..=core::cmp::min(current_deg, deg_max) { inner[i] += term_poly[i]; } } - // NC: degree-5 in X before multiplying by eq_beta_r(X). - // - // Accumulate into locals to avoid repeated loads/stores to `inner[*]` in the hot rho loop. - let mut nc0 = Fq::ZERO; - let mut nc1 = Fq::ZERO; - let mut nc2 = Fq::ZERO; - let mut nc3 = Fq::ZERO; - let mut nc4 = Fq::ZERO; - let mut nc5 = Fq::ZERO; - for w_i in 0..self.nc_tables.len() { - let lo = &self.nc_tables[w_i][idx]; - let hi = &self.nc_tables[w_i][idx + 1]; - let base = w_i * D; - - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho].real(); - let a = lo[rho].real(); - let b = hi[rho].real() - a; - - let a2 = a * a; - let a3 = a2 * a; - let a4 = a2 * a2; - let a5 = a4 * a; - - let b2 = b * b; - let b3 = b2 * b; - let b4 = b2 * b2; - let b5 = b4 * b; - - let t0 = a5 - a3 * five + a * four; - let p1 = a4 * five - a2 * fifteen + four; - let t1 = b * p1; - - let p2 = a3 * ten - a * fifteen; - let t2 = b2 * p2; - - let p3 = a2 * ten - five; - let t3 = b3 * p3; - - let t4 = b4 * (a * five); - let t5 = b5; - - nc0 += wg * t0; - nc1 += wg * t1; - nc2 += wg * t2; - nc3 += wg * t3; - nc4 += wg * t4; - nc5 += wg * t5; - } - } - inner[0] += nc0; - inner[1] += nc1; - inner[2] += nc2; - inner[3] += nc3; - inner[4] += nc4; - inner[5] += nc5; - coeffs[0] += e0 * 
inner[0]; for d in 1..=deg_max { coeffs[d] += (e0 * inner[d]) + (e1 * inner[d - 1]); @@ -1328,91 +1390,22 @@ impl RowStreamState { for term in &self.f_terms { term_poly.fill(Fq::ZERO); term_poly[0] = term.coeff.real(); + let mut current_deg = 0usize; for &(var_pos, exp) in &term.vars { - if exp == 0 { - continue; - } let tbl = &self.f_var_tables[var_pos]; // v(X) = a + b·X let a = tbl[idx].real(); let b = tbl[idx + 1].real() - a; for _ in 0..exp { - Self::poly_mul_affine_inplace_base(&mut term_poly, a, b); + Self::poly_mul_affine_inplace_base(&mut term_poly, a, b, current_deg); + current_deg += 1; } } - for i in 0..=deg_max { + for i in 0..=core::cmp::min(current_deg, deg_max) { inner[i] += term_poly[i]; } } - // NC: degree-5 in X before multiplying by eq_beta_r(X). - // - // Accumulate into locals to avoid repeated loads/stores to `inner[*]` in the hot rho loop. - let mut nc0 = Fq::ZERO; - let mut nc1 = Fq::ZERO; - let mut nc2 = Fq::ZERO; - let mut nc3 = Fq::ZERO; - let mut nc4 = Fq::ZERO; - let mut nc5 = Fq::ZERO; - for w_i in 0..self.nc_tables.len() { - let lo = &self.nc_tables[w_i][idx]; - let hi = &self.nc_tables[w_i][idx + 1]; - let base = w_i * D; - - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho].real(); - - // y(X) = a + b·X - let a = lo[rho].real(); - let b = hi[rho].real() - a; - - // Expand N(a + bX) = (a+bX)^5 - 5(a+bX)^3 + 4(a+bX). 
- // - // Coeffs: - // X^0: a^5 - 5a^3 + 4a - // X^1: b(5a^4 - 15a^2 + 4) - // X^2: b^2(10a^3 - 15a) - // X^3: b^3(10a^2 - 5) - // X^4: 5ab^4 - // X^5: b^5 - let a2 = a * a; - let a3 = a2 * a; - let a4 = a2 * a2; - let a5 = a4 * a; - - let b2 = b * b; - let b3 = b2 * b; - let b4 = b2 * b2; - let b5 = b4 * b; - - let t0 = a5 - a3 * five + a * four; - let p1 = a4 * five - a2 * fifteen + four; - let t1 = b * p1; - - let p2 = a3 * ten - a * fifteen; - let t2 = b2 * p2; - - let p3 = a2 * ten - five; - let t3 = b3 * p3; - - let t4 = b4 * (a * five); - let t5 = b5; - - nc0 += wg * t0; - nc1 += wg * t1; - nc2 += wg * t2; - nc3 += wg * t3; - nc4 += wg * t4; - nc5 += wg * t5; - } - } - inner[0] += nc0; - inner[1] += nc1; - inner[2] += nc2; - inner[3] += nc3; - inner[4] += nc4; - inner[5] += nc5; - // coeffs += eq_beta_r(X) * inner(X) coeffs[0] += e0 * inner[0]; for d in 1..=deg_max { @@ -1466,9 +1459,9 @@ impl RowStreamState { /// /// Coefficients are in low→high order. Output is truncated to the input length. #[inline] - fn poly_mul_affine_inplace(poly: &mut [K], a: K, b: K) { + fn poly_mul_affine_inplace(poly: &mut [K], a: K, b: K, current_deg: usize) { let mut prev = K::ZERO; - for coeff in poly.iter_mut() { + for coeff in poly.iter_mut().take(current_deg + 2) { let old = *coeff; *coeff = a * old + b * prev; prev = old; @@ -1492,8 +1485,6 @@ impl RowStreamState { return self.evals_row_phase_b2_base(tail_len, xs); } - let three = K::from(Ff::from_u64(3)); - let f_max_term_deg: usize = self .f_terms .iter() @@ -1505,8 +1496,8 @@ impl RowStreamState { }) .max() .unwrap_or(0); - // NC contributes degree 4 after multiplying by eq_beta_r(X). - let deg_max = core::cmp::max(4, f_max_term_deg + 1); + // eq_beta_r(X) adds one degree; Eval block is quadratic. 
+ let deg_max = core::cmp::max(2, f_max_term_deg + 1); let mut coeffs = vec![K::ZERO; deg_max + 1]; let mut inner = vec![K::ZERO; deg_max + 1]; @@ -1517,65 +1508,29 @@ impl RowStreamState { let e0 = self.eq_beta_r_tbl[2 * t]; let e1 = self.eq_beta_r_tbl[2 * t + 1] - e0; - // inner(X) = f_prime(X) + nc_total(X) + // inner(X) = f_prime(X) inner.fill(K::ZERO); // f_prime(X): expand sparse polynomial with affine substitutions. for term in &self.f_terms { term_poly.fill(K::ZERO); term_poly[0] = term.coeff; + let mut current_deg = 0usize; for &(var_pos, exp) in &term.vars { - if exp == 0 { - continue; - } let tbl = &self.f_var_tables[var_pos]; // v(X) = a + b·X let a = tbl[2 * t]; let b = tbl[2 * t + 1] - a; for _ in 0..exp { - Self::poly_mul_affine_inplace(&mut term_poly, a, b); + Self::poly_mul_affine_inplace(&mut term_poly, a, b, current_deg); + current_deg += 1; } } - for i in 0..=deg_max { + for i in 0..=core::cmp::min(current_deg, deg_max) { inner[i] += term_poly[i]; } } - // NC: degree-3 in X before multiplying by eq_beta_r(X). - for w_i in 0..self.nc_tables.len() { - let base = w_i * D; - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho]; - - // y(X) = a + b·X - let a = self.nc_tables[w_i][2 * t][rho]; - let b = self.nc_tables[w_i][2 * t + 1][rho] - a; - - // For b=2, N(y) = y(y^2-1) = y^3 - y. 
- let a2 = a * a; - let a3 = a2 * a; - let b2 = b * b; - let b3 = b2 * b; - - // (a + bX)^3 - (a + bX) - let t0 = a3 - a; - let t1 = (a2 * b).scale_base_k(three) - b; - let t2 = (a * b2).scale_base_k(three); - let t3 = b3; - - inner[0] += wg * t0; - if deg_max >= 1 { - inner[1] += wg * t1; - } - if deg_max >= 2 { - inner[2] += wg * t2; - } - if deg_max >= 3 { - inner[3] += wg * t3; - } - } - } - // coeffs += eq_beta_r(X) * inner(X) coeffs[0] += e0 * inner[0]; for d in 1..=deg_max { @@ -1619,11 +1574,6 @@ impl RowStreamState { return self.evals_row_phase_b3_base(tail_len, xs); } - let four = K::from(Ff::from_u64(4)); - let five = K::from(Ff::from_u64(5)); - let ten = K::from(Ff::from_u64(10)); - let fifteen = K::from(Ff::from_u64(15)); - let f_max_term_deg: usize = self .f_terms .iter() @@ -1635,8 +1585,8 @@ impl RowStreamState { }) .max() .unwrap_or(0); - // NC contributes degree 6 after multiplying by eq_beta_r(X). - let deg_max = core::cmp::max(6, f_max_term_deg + 1); + // eq_beta_r(X) adds one degree; Eval block is quadratic. + let deg_max = core::cmp::max(2, f_max_term_deg + 1); let coeffs = { #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] @@ -1656,84 +1606,29 @@ impl RowStreamState { let e0 = self.eq_beta_r_tbl[2 * t]; let e1 = self.eq_beta_r_tbl[2 * t + 1] - e0; - // inner(X) = f_prime(X) + nc_total(X) + // inner(X) = f_prime(X) inner.fill(K::ZERO); // f_prime(X): expand sparse polynomial with affine substitutions. 
for term in &self.f_terms { term_poly.fill(K::ZERO); term_poly[0] = term.coeff; + let mut current_deg = 0usize; for &(var_pos, exp) in &term.vars { - if exp == 0 { - continue; - } let tbl = &self.f_var_tables[var_pos]; // v(X) = a + b·X let a = tbl[2 * t]; let b = tbl[2 * t + 1] - a; for _ in 0..exp { - Self::poly_mul_affine_inplace(&mut term_poly, a, b); + Self::poly_mul_affine_inplace(&mut term_poly, a, b, current_deg); + current_deg += 1; } } - for i in 0..=deg_max { + for i in 0..=core::cmp::min(current_deg, deg_max) { inner[i] += term_poly[i]; } } - // NC: degree-5 in X before multiplying by eq_beta_r(X). - for w_i in 0..self.nc_tables.len() { - let lo = &self.nc_tables[w_i][2 * t]; - let hi = &self.nc_tables[w_i][2 * t + 1]; - let base = w_i * D; - - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho]; - - // y(X) = a + b·X - let a = lo[rho]; - let b = hi[rho] - a; - - // Expand N(a + bX) = (a+bX)^5 - 5(a+bX)^3 + 4(a+bX). - // - // Coeffs: - // X^0: a^5 - 5a^3 + 4a - // X^1: b(5a^4 - 15a^2 + 4) - // X^2: b^2(10a^3 - 15a) - // X^3: b^3(10a^2 - 5) - // X^4: 5ab^4 - // X^5: b^5 - let a2 = a * a; - let a3 = a2 * a; - let a4 = a2 * a2; - let a5 = a4 * a; - - let b2 = b * b; - let b3 = b2 * b; - let b4 = b2 * b2; - let b5 = b4 * b; - - let t0 = a5 - a3.scale_base_k(five) + a.scale_base_k(four); - let p1 = a4.scale_base_k(five) - a2.scale_base_k(fifteen) + four; - let t1 = b * p1; - - let p2 = a3.scale_base_k(ten) - a.scale_base_k(fifteen); - let t2 = b2 * p2; - - let p3 = a2.scale_base_k(ten) - five; - let t3 = b3 * p3; - - let t4 = b4 * a.scale_base_k(five); - let t5 = b5; - - inner[0] += wg * t0; - inner[1] += wg * t1; - inner[2] += wg * t2; - inner[3] += wg * t3; - inner[4] += wg * t4; - inner[5] += wg * t5; - } - } - // coeffs += eq_beta_r(X) * inner(X) coeffs[0] += e0 * inner[0]; for d in 1..=deg_max { @@ -1780,84 +1675,29 @@ impl RowStreamState { let e0 = self.eq_beta_r_tbl[2 * t]; let e1 = self.eq_beta_r_tbl[2 * t + 1] - e0; - // inner(X) = f_prime(X) + 
nc_total(X) + // inner(X) = f_prime(X) inner.fill(K::ZERO); // f_prime(X): expand sparse polynomial with affine substitutions. for term in &self.f_terms { term_poly.fill(K::ZERO); term_poly[0] = term.coeff; + let mut current_deg = 0usize; for &(var_pos, exp) in &term.vars { - if exp == 0 { - continue; - } let tbl = &self.f_var_tables[var_pos]; // v(X) = a + b·X let a = tbl[2 * t]; let b = tbl[2 * t + 1] - a; for _ in 0..exp { - Self::poly_mul_affine_inplace(&mut term_poly, a, b); + Self::poly_mul_affine_inplace(&mut term_poly, a, b, current_deg); + current_deg += 1; } } - for i in 0..=deg_max { + for i in 0..=core::cmp::min(current_deg, deg_max) { inner[i] += term_poly[i]; } } - // NC: degree-5 in X before multiplying by eq_beta_r(X). - for w_i in 0..self.nc_tables.len() { - let lo = &self.nc_tables[w_i][2 * t]; - let hi = &self.nc_tables[w_i][2 * t + 1]; - let base = w_i * D; - - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho]; - - // y(X) = a + b·X - let a = lo[rho]; - let b = hi[rho] - a; - - // Expand N(a + bX) = (a+bX)^5 - 5(a+bX)^3 + 4(a+bX). 
- // - // Coeffs: - // X^0: a^5 - 5a^3 + 4a - // X^1: b(5a^4 - 15a^2 + 4) - // X^2: b^2(10a^3 - 15a) - // X^3: b^3(10a^2 - 5) - // X^4: 5ab^4 - // X^5: b^5 - let a2 = a * a; - let a3 = a2 * a; - let a4 = a2 * a2; - let a5 = a4 * a; - - let b2 = b * b; - let b3 = b2 * b; - let b4 = b2 * b2; - let b5 = b4 * b; - - let t0 = a5 - a3.scale_base_k(five) + a.scale_base_k(four); - let p1 = a4.scale_base_k(five) - a2.scale_base_k(fifteen) + four; - let t1 = b * p1; - - let p2 = a3.scale_base_k(ten) - a.scale_base_k(fifteen); - let t2 = b2 * p2; - - let p3 = a2.scale_base_k(ten) - five; - let t3 = b3 * p3; - - let t4 = b4 * a.scale_base_k(five); - let t5 = b5; - - inner[0] += wg * t0; - inner[1] += wg * t1; - inner[2] += wg * t2; - inner[3] += wg * t3; - inner[4] += wg * t4; - inner[5] += wg * t5; - } - } - // coeffs += eq_beta_r(X) * inner(X) coeffs[0] += e0 * inner[0]; for d in 1..=deg_max { @@ -1927,21 +1767,7 @@ impl RowStreamState { f_prime += acc; } - // NC: Σ_i Σ_rho (w_beta_a[rho]·gamma_i) · N_i(y_i_rho) - let mut nc_total = K::ZERO; - for w_i in 0..self.nc_tables.len() { - let base = w_i * D; - for rho in 0..D { - let wg = self.w_gamma_nc[base + rho]; - let lo = self.nc_tables[w_i][2 * t][rho]; - let hi = self.nc_tables[w_i][2 * t + 1][rho]; - let y = one_minus * lo + x * hi; - let ni = range_product_cached(y, &self.range_t_sq); - nc_total += wg * ni; - } - } - - let mut out = eq_beta_r * (f_prime + nc_total); + let mut out = eq_beta_r * f_prime; // Eval: eq_r_inputs(r') * gamma_to_k * eval_tbl(r') if let (Some(eq_tbl), Some(eval_tbl)) = (self.eq_r_inputs_tbl.as_ref(), self.eval_tbl.as_ref()) { @@ -2163,7 +1989,6 @@ where ell_n, mcs_witnesses, me_witnesses, - false, r_inputs, sparse.as_ref(), ); @@ -2206,7 +2031,6 @@ where /// Precompute all data that depends only on r' (not on α') for row phase optimization. /// This eliminates redundant v_j recomputation across all boolean α' assignments. 
fn precompute_for_r(&self, r_prime: &[K]) -> RPrecomp { - let k_total = self.mcs_witnesses.len() + self.me_witnesses.len(); let t = self.s.t(); // Build χ_r table over the Boolean row domain. @@ -2285,9 +2109,6 @@ where } let f_prime = self.s.f.eval_in_ext::(&m_vals); - // Precompute Y_eval[i][j][ρ] = (Z_i · v_j)[ρ] for all instances and matrices - let mut y_eval = vec![vec![[K::ZERO; D]; t]; k_total]; - let all_witnesses: Vec<&Mat> = self .mcs_witnesses .iter() @@ -2295,17 +2116,53 @@ where .chain(self.me_witnesses.iter()) .collect(); - for (idx, Zi) in all_witnesses.iter().enumerate() { - for j in 0..t { - for rho in 0..D { - let mut acc = K::ZERO; - for &(c, v) in &vjs_nz[j] { - acc += v.scale_base_k(K::from(Zi[(rho, c)])); - } - y_eval[idx][j][rho] = acc; - } + // Precompute Y_eval[i][j][ρ] = (Z_i · v_j)[ρ] for all instances and matrices. + let y_eval: Vec> = { + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + all_witnesses + .par_iter() + .map(|Zi| { + (0..t) + .map(|j| { + let mut y_row = [K::ZERO; D]; + for rho in 0..D { + let mut acc = K::ZERO; + let z_row = Zi.row(rho); + for &(c, v) in &vjs_nz[j] { + acc += v.scale_base_k(K::from(z_row[c])); + } + y_row[rho] = acc; + } + y_row + }) + .collect() + }) + .collect() } - } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + all_witnesses + .iter() + .map(|Zi| { + (0..t) + .map(|j| { + let mut y_row = [K::ZERO; D]; + for rho in 0..D { + let mut acc = K::ZERO; + let z_row = Zi.row(rho); + for &(c, v) in &vjs_nz[j] { + acc += v.scale_base_k(K::from(z_row[c])); + } + y_row[rho] = acc; + } + y_row + }) + .collect() + }) + .collect() + } + }; RPrecomp { y_eval, diff --git a/crates/neo-reductions/src/engines/optimized_engine/sparse.rs b/crates/neo-reductions/src/engines/optimized_engine/sparse.rs index e784fa95..1d7b9288 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/sparse.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/sparse.rs @@ -163,7 
+163,7 @@ impl CscMat { /// Only reads rows < n_eff. pub fn add_mul_transpose_into(&self, x: &[Kf], y: &mut [Kf], n_eff: usize) where - Kf: Copy + core::ops::AddAssign + core::ops::Mul + From, + Kf: Copy + core::ops::AddAssign + core::ops::Mul + From + Send + Sync, { debug_assert!(n_eff <= self.nrows, "n_eff must be <= nrows"); debug_assert!( @@ -174,14 +174,40 @@ impl CscMat { ); debug_assert_eq!(y.len(), self.ncols); - for c in 0..self.ncols { - let s = self.col_ptr[c]; - let e = self.col_ptr[c + 1]; - for k in s..e { - let r = self.row_idx[k]; - if r < n_eff { - y[c] += Kf::from(self.vals[k]) * x[r]; + #[cfg(any(not(target_arch = "wasm32"), feature = "wasm-threads"))] + { + y.par_iter_mut().enumerate().for_each(|(c, yc)| { + let s = self.col_ptr[c]; + let e = self.col_ptr[c + 1]; + if s == e { + return; + } + let mut sum = *yc; + for k in s..e { + let r = self.row_idx[k]; + if r < n_eff { + sum += Kf::from(self.vals[k]) * x[r]; + } + } + *yc = sum; + }); + } + #[cfg(all(target_arch = "wasm32", not(feature = "wasm-threads")))] + { + for c in 0..self.ncols { + let s = self.col_ptr[c]; + let e = self.col_ptr[c + 1]; + if s == e { + continue; } + let mut sum = y[c]; + for k in s..e { + let r = self.row_idx[k]; + if r < n_eff { + sum += Kf::from(self.vals[k]) * x[r]; + } + } + y[c] = sum; } } } @@ -190,14 +216,18 @@ impl CscMat { /// Only updates rows < n_eff. 
pub fn add_mul_into(&self, x: &[Kf], y: &mut [Kf], n_eff: usize) where - Kf: Copy + core::ops::AddAssign + core::ops::Mul + From, + Kf: Copy + core::ops::AddAssign + core::ops::Mul + From + PartialEq, { debug_assert!(n_eff <= self.nrows, "n_eff must be <= nrows"); debug_assert!(y.len() >= n_eff, "y.len() must be >= n_eff"); debug_assert_eq!(x.len(), self.ncols); + let zero = Kf::from(Ff::ZERO); for c in 0..self.ncols { let xc = x[c]; + if xc == zero { + continue; + } let s = self.col_ptr[c]; let e = self.col_ptr[c + 1]; for k in s..e { diff --git a/crates/neo-reductions/tests/nc_col_fast_path_equivalence.rs b/crates/neo-reductions/tests/nc_col_fast_path_equivalence.rs new file mode 100644 index 00000000..aa671db5 --- /dev/null +++ b/crates/neo-reductions/tests/nc_col_fast_path_equivalence.rs @@ -0,0 +1,84 @@ +#![allow(non_snake_case)] + +use neo_ccs::{CcsStructure, Mat, McsWitness, SparsePoly}; +use neo_math::{D, F, K}; +use neo_params::NeoParams; +use neo_reductions::engines::utils::build_dims_and_policy; +use neo_reductions::optimized_engine::oracle::NcOracle; +use neo_reductions::optimized_engine::Challenges; +use neo_reductions::sumcheck::RoundOracle; +use p3_field::PrimeCharacteristicRing; + +fn identity_left(n: usize, m: usize) -> Mat { + let mut mat = Mat::zero(n, m, F::ZERO); + for i in 0..n.min(m) { + mat.set(i, i, F::ONE); + } + mat +} + +fn run_fast_vs_generic(b: u32) { + let n = 4usize; + let m = 8usize; + + let mut params = NeoParams::goldilocks_auto_r1cs_ccs(n).expect("params"); + params.b = b; + + let s = CcsStructure::new(vec![identity_left(n, m)], SparsePoly::new(1, vec![])).expect("ccs"); + let dims = build_dims_and_policy(¶ms, &s).expect("dims"); + + let mut data = Vec::with_capacity(D * m); + for rho in 0..D { + for c in 0..m { + data.push(F::from_u64(7 + (rho as u64) * 19 + (c as u64) * 23)); + } + } + let Z = Mat::from_row_major(D, m, data); + let mcs_witnesses = vec![McsWitness { w: vec![F::ZERO; m], Z }]; + + let ch = Challenges { + alpha: 
(0..dims.ell_d) + .map(|i| K::from(F::from_u64(100 + i as u64))) + .collect(), + beta_a: (0..dims.ell_d) + .map(|i| K::from(F::from_u64(200 + i as u64))) + .collect(), + beta_r: (0..dims.ell_n) + .map(|i| K::from(F::from_u64(300 + i as u64))) + .collect(), + beta_m: (0..dims.ell_m) + .map(|i| K::from(F::from_u64(400 + i as u64))) + .collect(), + gamma: K::from(F::from_u64(777)), + }; + + let mut oracle = NcOracle::new(&s, ¶ms, &mcs_witnesses, &[], ch, dims.ell_d, dims.ell_m, dims.d_sc); + let xs = vec![ + K::from(F::ZERO), + K::from(F::ONE), + K::from(F::from_u64(2)), + K::from(F::from_u64(5)), + K::from(F::from_u64(9)), + ]; + + for round in 0..dims.ell_m { + let (fast, generic) = oracle + .__test_col_phase_fast_vs_generic(&xs) + .expect("must be in NC column phase"); + assert_eq!( + fast, generic, + "NcOracle fast col-phase mismatch at b={b}, round={round}" + ); + oracle.fold(K::from(F::from_u64(900 + round as u64))); + } +} + +#[test] +fn nc_col_phase_fast_path_matches_generic_b2() { + run_fast_vs_generic(2); +} + +#[test] +fn nc_col_phase_fast_path_matches_generic_b3() { + run_fast_vs_generic(3); +} diff --git a/crates/neo-spartan-bridge/tests/fold_run_circuit_smoke.rs b/crates/neo-spartan-bridge/tests/fold_run_circuit_smoke.rs index b19fc588..968c6184 100644 --- a/crates/neo-spartan-bridge/tests/fold_run_circuit_smoke.rs +++ b/crates/neo-spartan-bridge/tests/fold_run_circuit_smoke.rs @@ -123,7 +123,9 @@ fn build_trivial_fold_run_and_instance() -> (FoldRunInstance, FoldRunWitness) { steps: vec![StepProof { fold: step, mem: MemSidecarProof { - cpu_me_claims_val: Vec::new(), + val_me_claims: Vec::new(), + wb_me_claims: Vec::new(), + wp_me_claims: Vec::new(), shout_addr_pre: Default::default(), proofs: Vec::new(), }, @@ -133,7 +135,9 @@ fn build_trivial_fold_run_and_instance() -> (FoldRunInstance, FoldRunWitness) { labels: Vec::new(), round_polys: Vec::new(), }, - val_fold: None, + val_fold: Vec::new(), + wb_fold: Vec::new(), + wp_fold: Vec::new(), }], 
output_proof: None, }; @@ -286,7 +290,7 @@ fn fold_run_circuit_optimized_nontrivial_satisfied() { }; let initial_accumulator = acc.me.clone(); - let mut session = FoldingSession::new(FoldingMode::PaperExact, params, l.clone()) + let mut session = FoldingSession::new(FoldingMode::Optimized, params, l.clone()) .with_initial_accumulator(acc, &ccs) .expect("with_initial_accumulator"); diff --git a/demos/wasm-demo/README.md b/demos/wasm-demo/README.md index b1095ded..00273e57 100644 --- a/demos/wasm-demo/README.md +++ b/demos/wasm-demo/README.md @@ -9,7 +9,7 @@ This demo supports two input modes: - **TestExport JSON** (same schema as `crates/neo-fold/poseidon2-tests/*.json`): R1CS A/B/C sparse matrices + per-step witnesses. - **RV32 Fibonacci** (pasteable “mini-asm” / words): assembles a small RV32 program, generates a trace, - and proves it under the RV32 B1 shared-bus step circuit. + and proves it under the RV32 trace-wiring shared-bus circuit. ## API surface / extending from JS diff --git a/demos/wasm-demo/wasm/src/lib.rs b/demos/wasm-demo/wasm/src/lib.rs index 686bb64c..6845a536 100644 --- a/demos/wasm-demo/wasm/src/lib.rs +++ b/demos/wasm-demo/wasm/src/lib.rs @@ -72,18 +72,18 @@ fn fib_u32(n: u32) -> u32 { a } -/// Prove+verify the RV32 Fibonacci program under the B1 shared-bus step circuit. +/// Prove+verify the RV32 Fibonacci program under trace-wiring mode. 
/// /// Expected guest semantics: /// - reads `n` from RAM[0x104] (u32) /// - writes `fib(n)` to RAM[0x100] (u32) /// - halts via `ecall` (treated as `Halt` in this VM) #[wasm_bindgen] -pub fn prove_verify_rv32_b1_fibonacci_asm( +pub fn prove_verify_rv32_trace_fibonacci_asm( asm: &str, n: u32, - ram_bytes: usize, - chunk_size: usize, + _ram_bytes: usize, + chunk_rows: usize, max_steps: usize, do_spartan: bool, ) -> Result { @@ -96,11 +96,10 @@ pub fn prove_verify_rv32_b1_fibonacci_asm( let expected_f = F::from_u64(expected as u64); let mut run = { - let mut b = neo_fold::riscv_shard::Rv32B1::from_rom(/*program_base=*/ 0, &program_bytes) + let mut b = neo_fold::riscv_trace_shard::Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) .xlen(32) - .ram_bytes(ram_bytes) .ram_init_u32(/*addr=*/ 0x104, n) - .chunk_size(chunk_size) + .chunk_rows(chunk_rows) .shout_auto_minimal() .output(/*output_addr=*/ 0x100, /*expected_output=*/ expected_f); if max_steps > 0 { @@ -118,11 +117,11 @@ pub fn prove_verify_rv32_b1_fibonacci_asm( .map(|d| d.as_secs_f64() * 1000.0) .unwrap_or(0.0); - let trace_len = run.riscv_trace_len().ok(); + let trace_len = Some(run.trace_len()); let folds = run.fold_count(); let ccs_constraints = run.ccs_num_constraints(); let ccs_variables = run.ccs_num_variables(); - let shout_lookups = run.shout_lookup_count().ok(); + let shout_lookups = Some(run.exec_table().rows.iter().map(|r| r.shout_events.len()).sum()); let spartan = if do_spartan { let acc_init = &[]; diff --git a/demos/wasm-demo/web/index.html b/demos/wasm-demo/web/index.html index 8072e51d..bda34b4b 100644 --- a/demos/wasm-demo/web/index.html +++ b/demos/wasm-demo/web/index.html @@ -89,7 +89,7 @@

Controls

RISC-V mode expects a small “mini-asm” subset (or one 32-bit word per line). diff --git a/demos/wasm-demo/web/pkg/neo_fold_demo.d.ts b/demos/wasm-demo/web/pkg/neo_fold_demo.d.ts index d1fb1bc5..2f2f0f03 100644 --- a/demos/wasm-demo/web/pkg/neo_fold_demo.d.ts +++ b/demos/wasm-demo/web/pkg/neo_fold_demo.d.ts @@ -89,14 +89,14 @@ export class SpartanCompressedProof { export function init_panic_hook(): void; /** - * Prove+verify the RV32 Fibonacci program under the B1 shared-bus step circuit. + * Prove+verify the RV32 Fibonacci program under the trace-wiring shared-bus circuit. * * Expected guest semantics: * - reads `n` from RAM[0x104] (u32) * - writes `fib(n)` to RAM[0x100] (u32) * - halts via `ecall` (treated as `Halt` in this VM) */ -export function prove_verify_rv32_b1_fibonacci_asm(asm: string, n: number, ram_bytes: number, chunk_size: number, max_steps: number, do_spartan: boolean): any; +export function prove_verify_rv32_trace_fibonacci_asm(asm: string, n: number, ram_bytes: number, chunk_size: number, max_steps: number, do_spartan: boolean): any; /** * Parse a `TestExport` JSON (same schema as `crates/neo-fold/poseidon2-tests/*.json`), @@ -110,7 +110,7 @@ export interface InitOutput { readonly memory: WebAssembly.Memory; readonly init_panic_hook: () => void; readonly prove_verify_test_export_json: (a: number, b: number) => [number, number, number]; - readonly prove_verify_rv32_b1_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; + readonly prove_verify_rv32_trace_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; readonly __wbg_neofoldsession_free: (a: number, b: number) => void; readonly neofoldsession_new: (a: number, b: number) => [number, number, number]; readonly neofoldsession_step_count: (a: number) => number; diff --git a/demos/wasm-demo/web/pkg/neo_fold_demo.js b/demos/wasm-demo/web/pkg/neo_fold_demo.js index 
bf949831..ddc7cd6a 100644 --- a/demos/wasm-demo/web/pkg/neo_fold_demo.js +++ b/demos/wasm-demo/web/pkg/neo_fold_demo.js @@ -301,7 +301,7 @@ export function init_panic_hook() { } /** - * Prove+verify the RV32 Fibonacci program under the B1 shared-bus step circuit. + * Prove+verify the RV32 Fibonacci program under the trace-wiring shared-bus circuit. * * Expected guest semantics: * - reads `n` from RAM[0x104] (u32) @@ -315,10 +315,10 @@ export function init_panic_hook() { * @param {boolean} do_spartan * @returns {any} */ -export function prove_verify_rv32_b1_fibonacci_asm(asm, n, ram_bytes, chunk_size, max_steps, do_spartan) { +export function prove_verify_rv32_trace_fibonacci_asm(asm, n, ram_bytes, chunk_size, max_steps, do_spartan) { const ptr0 = passStringToWasm0(asm, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc); const len0 = WASM_VECTOR_LEN; - const ret = wasm.prove_verify_rv32_b1_fibonacci_asm(ptr0, len0, n, ram_bytes, chunk_size, max_steps, do_spartan); + const ret = wasm.prove_verify_rv32_trace_fibonacci_asm(ptr0, len0, n, ram_bytes, chunk_size, max_steps, do_spartan); if (ret[2]) { throw takeFromExternrefTable0(ret[1]); } diff --git a/demos/wasm-demo/web/pkg/neo_fold_demo_bg.wasm.d.ts b/demos/wasm-demo/web/pkg/neo_fold_demo_bg.wasm.d.ts index 4a86c956..ed87ad2c 100644 --- a/demos/wasm-demo/web/pkg/neo_fold_demo_bg.wasm.d.ts +++ b/demos/wasm-demo/web/pkg/neo_fold_demo_bg.wasm.d.ts @@ -3,7 +3,7 @@ export const memory: WebAssembly.Memory; export const init_panic_hook: () => void; export const prove_verify_test_export_json: (a: number, b: number) => [number, number, number]; -export const prove_verify_rv32_b1_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; +export const prove_verify_rv32_trace_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; export const __wbg_neofoldsession_free: (a: number, b: number) => void; 
export const neofoldsession_new: (a: number, b: number) => [number, number, number]; export const neofoldsession_step_count: (a: number) => number; diff --git a/demos/wasm-demo/web/pkg_threads/neo_fold_demo.d.ts b/demos/wasm-demo/web/pkg_threads/neo_fold_demo.d.ts index 63057b1b..c81bae85 100644 --- a/demos/wasm-demo/web/pkg_threads/neo_fold_demo.d.ts +++ b/demos/wasm-demo/web/pkg_threads/neo_fold_demo.d.ts @@ -93,14 +93,14 @@ export function init_panic_hook(): void; export function init_thread_pool(num_threads: number): Promise; /** - * Prove+verify the RV32 Fibonacci program under the B1 shared-bus step circuit. + * Prove+verify the RV32 Fibonacci program under the trace-wiring shared-bus circuit. * * Expected guest semantics: * - reads `n` from RAM[0x104] (u32) * - writes `fib(n)` to RAM[0x100] (u32) * - halts via `ecall` (treated as `Halt` in this VM) */ -export function prove_verify_rv32_b1_fibonacci_asm(asm: string, n: number, ram_bytes: number, chunk_size: number, max_steps: number, do_spartan: boolean): any; +export function prove_verify_rv32_trace_fibonacci_asm(asm: string, n: number, ram_bytes: number, chunk_size: number, max_steps: number, do_spartan: boolean): any; /** * Parse a `TestExport` JSON (same schema as `crates/neo-fold/poseidon2-tests/*.json`), @@ -145,7 +145,7 @@ export interface InitOutput { readonly neofoldsession_spartan_verify: (a: number, b: number) => [number, number, number]; readonly neofoldsession_step_count: (a: number) => number; readonly neofoldsession_verify: (a: number, b: number) => [number, number, number]; - readonly prove_verify_rv32_b1_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; + readonly prove_verify_rv32_trace_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; readonly prove_verify_test_export_json: (a: number, b: number) => [number, number, number]; readonly 
spartancompressedproof_bytes: (a: number) => [number, number, number, number]; readonly spartancompressedproof_bytes_len: (a: number) => [number, number, number]; diff --git a/demos/wasm-demo/web/pkg_threads/neo_fold_demo.js b/demos/wasm-demo/web/pkg_threads/neo_fold_demo.js index 01190588..1d0e60bc 100644 --- a/demos/wasm-demo/web/pkg_threads/neo_fold_demo.js +++ b/demos/wasm-demo/web/pkg_threads/neo_fold_demo.js @@ -320,7 +320,7 @@ export function init_thread_pool(num_threads) { } /** - * Prove+verify the RV32 Fibonacci program under the B1 shared-bus step circuit. + * Prove+verify the RV32 Fibonacci program under the trace-wiring shared-bus circuit. * * Expected guest semantics: * - reads `n` from RAM[0x104] (u32) @@ -334,10 +334,10 @@ export function init_thread_pool(num_threads) { * @param {boolean} do_spartan * @returns {any} */ -export function prove_verify_rv32_b1_fibonacci_asm(asm, n, ram_bytes, chunk_size, max_steps, do_spartan) { +export function prove_verify_rv32_trace_fibonacci_asm(asm, n, ram_bytes, chunk_size, max_steps, do_spartan) { const ptr0 = passStringToWasm0(asm, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc); const len0 = WASM_VECTOR_LEN; - const ret = wasm.prove_verify_rv32_b1_fibonacci_asm(ptr0, len0, n, ram_bytes, chunk_size, max_steps, do_spartan); + const ret = wasm.prove_verify_rv32_trace_fibonacci_asm(ptr0, len0, n, ram_bytes, chunk_size, max_steps, do_spartan); if (ret[2]) { throw takeFromExternrefTable0(ret[1]); } diff --git a/demos/wasm-demo/web/pkg_threads/neo_fold_demo_bg.wasm.d.ts b/demos/wasm-demo/web/pkg_threads/neo_fold_demo_bg.wasm.d.ts index f52c72b7..9383ba76 100644 --- a/demos/wasm-demo/web/pkg_threads/neo_fold_demo_bg.wasm.d.ts +++ b/demos/wasm-demo/web/pkg_threads/neo_fold_demo_bg.wasm.d.ts @@ -22,7 +22,7 @@ export const neofoldsession_spartan_prove: (a: number, b: number) => [number, nu export const neofoldsession_spartan_verify: (a: number, b: number) => [number, number, number]; export const 
neofoldsession_step_count: (a: number) => number; export const neofoldsession_verify: (a: number, b: number) => [number, number, number]; -export const prove_verify_rv32_b1_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; +export const prove_verify_rv32_trace_fibonacci_asm: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => [number, number, number]; export const prove_verify_test_export_json: (a: number, b: number) => [number, number, number]; export const spartancompressedproof_bytes: (a: number) => [number, number, number, number]; export const spartancompressedproof_bytes_len: (a: number) => [number, number, number]; diff --git a/demos/wasm-demo/web/prover_worker.js b/demos/wasm-demo/web/prover_worker.js index e958d20a..fc1b54f8 100644 --- a/demos/wasm-demo/web/prover_worker.js +++ b/demos/wasm-demo/web/prover_worker.js @@ -445,14 +445,14 @@ async function runRv32Fibonacci({ id, asm, riscv, doSpartan, bundle, threads }) if (!Number.isFinite(chunkSize) || chunkSize <= 0) throw new Error("Invalid riscv.chunk_size"); if (!Number.isFinite(maxSteps) || maxSteps < 0) throw new Error("Invalid riscv.max_steps"); - log(id, "Running RV32 Fibonacci prove+verify…"); + log(id, "Running RV32 trace Fibonacci prove+verify…"); log(id, `Input text size: ${fmtBytes(src.length)}`); log(id, `Config: n=${n} ram_bytes=${ramBytes} chunk_size=${chunkSize} max_steps=${maxSteps}`); async function runOnce() { phase(id, "Proving…"); const totalStart = performance.now(); - const result = wasm.prove_verify_rv32_b1_fibonacci_asm( + const result = wasm.prove_verify_rv32_trace_fibonacci_asm( src, n, ramBytes, @@ -486,7 +486,7 @@ async function runRv32Fibonacci({ id, asm, riscv, doSpartan, bundle, threads }) notifyThreadsDisabled(id, threadsDisableReason); await ensureWasm({ id, bundle: "pkg", threads: 0 }); - log(id, "Retrying RV32 Fibonacci prove+verify in single-thread mode…", "warn"); + 
log(id, "Retrying RV32 trace Fibonacci prove+verify in single-thread mode…", "warn"); ({ result, totalMs } = await runOnce()); } else { throw e; diff --git a/file_aggregator.sh b/file_aggregator.sh index 9f2a4ef5..c67bea46 100755 --- a/file_aggregator.sh +++ b/file_aggregator.sh @@ -126,6 +126,46 @@ format_number() { }' } +# --------------------------------------------------------------------------- +# Token counting — prefer tiktoken (accurate) with bytes/4 fallback +# --------------------------------------------------------------------------- +# Tries to use ~/smart-context-packer/token_counter.py with its venv for exact +# BPE token counts. If the script or tiktoken is unavailable, falls back to +# the bytes/4 heuristic silently. +TIKTOKEN_PYTHON="$HOME/smart-context-packer/.venv/bin/python3" +TIKTOKEN_SCRIPT="$HOME/smart-context-packer/token_counter.py" +HAS_TIKTOKEN=false + +if [ -x "$TIKTOKEN_PYTHON" ] && [ -f "$TIKTOKEN_SCRIPT" ]; then + # Quick smoke test: make sure tiktoken is importable + if "$TIKTOKEN_PYTHON" -c "import tiktoken" 2>/dev/null; then + HAS_TIKTOKEN=true + fi +fi + +# Count tokens for a file. Outputs a single integer. +count_tokens_file() { + local file="$1" + if $HAS_TIKTOKEN; then + "$TIKTOKEN_PYTHON" "$TIKTOKEN_SCRIPT" "$file" 2>/dev/null + else + local size + size=$(wc -c < "$file") + echo $((size / 4)) + fi +} + +# Count tokens for text on stdin. Outputs a single integer. +count_tokens_stdin() { + if $HAS_TIKTOKEN; then + "$TIKTOKEN_PYTHON" "$TIKTOKEN_SCRIPT" 2>/dev/null + else + local size + size=$(wc -c) + echo $((size / 4)) + fi +} + # Remove trailing slashes from all directories for i in "${!dirs[@]}"; do dirs[$i]="${dirs[$i]%/}" @@ -161,7 +201,10 @@ for dir in "${dirs[@]}"; do done < <(eval "$find_cmd") done -# Deduplicate while preserving input order, then process +# Deduplicate while preserving input order, then process. +# Per-file token counts use the fast bytes/4 estimate (calling Python per-file +# would be too slow). 
The accurate tiktoken count is done once at the end on +# the complete output file. printf '%s\n' "${files_to_process[@]}" | awk '!seen[$0]++' | while IFS= read -r file; do [ -e "$file" ] || continue # Get size before processing this file @@ -180,7 +223,7 @@ printf '%s\n' "${files_to_process[@]}" | awk '!seen[$0]++' | while IFS= read -r cat "$file" >> "$outfile" fi echo >> "$outfile" - # Calculate this file's token contribution + # Per-file estimate (bytes/4 — fast, just for progress display) size_after=$(wc -c < "$outfile") file_size=$((size_after - size_before)) file_tokens=$((file_size / 4)) @@ -191,12 +234,20 @@ done if [ -f "$outfile" ]; then final_size=$(wc -c < "$outfile") final_words=$(wc -w < "$outfile") - # Approximate AI token count (1 token ≈ 4 characters for English text) - final_ai_tokens=$((final_size / 4)) + + # Token count: use tiktoken (accurate) if available, else bytes/4 estimate + if $HAS_TIKTOKEN; then + final_ai_tokens=$(count_tokens_file "$outfile") + token_label="tiktoken" + else + final_ai_tokens=$((final_size / 4)) + token_label="est" + fi + echo echo "=== Final Output Statistics ===" echo "Output file: $outfile" echo "Total size: $(format_number $final_size) bytes" echo "Total words: $(format_number $final_words)" - echo "Total AI tokens (est): $(format_number $final_ai_tokens)" + echo "Total AI tokens ($token_label): $(format_number $final_ai_tokens)" fi diff --git a/show_diff.sh b/show_diff.sh index 1a3c6997..6c1f3c05 100755 --- a/show_diff.sh +++ b/show_diff.sh @@ -108,10 +108,14 @@ fi echo "Git Status Summary" echo "==============================================" if [ ${#paths[@]} -gt 0 ]; then - git status --short -- "${paths[@]}" + status_output=$(git status --short -- "${paths[@]}") else - git status --short + status_output=$(git status --short) fi + if [ "$no_tests" = true ] && [ -n "$status_output" ]; then + status_output=$(filter_test_files "$status_output") + fi + echo "$status_output" echo "" } >> "$output_file"