From 809fe178948368e998fcc79eb289ae87c1a1e783 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Wed, 18 Feb 2026 13:29:04 -0600 Subject: [PATCH 1/2] mostly clean up twist oracles bloat Signed-off-by: Nico Arqueros --- ...v_fibonacci_compiled_trace_prove_verify.rs | 53 + ...scv_program_compiled_trace_prove_verify.rs | 49 + ..._riscv_program_rv32m_trace_prove_verify.rs | 215 + ..._u64_output_compiled_trace_prove_verify.rs | 49 + .../src/memory_sidecar/route_a_time.rs | 499 +- .../riscv_trace_ccs_diverse_programs.rs | 257 + .../shared_bus/trace_bus_binding_redteam.rs | 113 + .../shared_bus/trace_main_proof_redteam.rs | 137 + .../shared_bus/trace_twist_shout_redteam.rs | 172 + crates/neo-memory/src/builder.rs | 31 +- crates/neo-memory/src/cpu/constraints.rs | 120 +- crates/neo-memory/src/twist_oracle.rs | 7669 ++++------------- 12 files changed, 3207 insertions(+), 6157 deletions(-) create mode 100644 crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_trace_prove_verify.rs create mode 100644 crates/neo-fold/riscv-tests/test_riscv_program_compiled_trace_prove_verify.rs create mode 100644 crates/neo-fold/riscv-tests/test_riscv_program_rv32m_trace_prove_verify.rs create mode 100644 crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_trace_prove_verify.rs create mode 100644 crates/neo-fold/tests/suites/integration/riscv_trace_ccs_diverse_programs.rs create mode 100644 crates/neo-fold/tests/suites/shared_bus/trace_bus_binding_redteam.rs create mode 100644 crates/neo-fold/tests/suites/shared_bus/trace_main_proof_redteam.rs create mode 100644 crates/neo-fold/tests/suites/shared_bus/trace_twist_shout_redteam.rs diff --git a/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_trace_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_trace_prove_verify.rs new file mode 100644 index 00000000..754046a3 --- /dev/null +++ b/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_trace_prove_verify.rs @@ -0,0 +1,53 @@ +//! End-to-end prove+verify for a compiled RV32 Fibonacci guest program under the trace-mode runner. +//! +//! The guest is authored in Rust under `riscv-tests/guests/rv32-fibonacci/` and its ROM bytes are committed in +//! `riscv-tests/binaries/rv32_fibonacci_rom.rs` so this test doesn't need to cross-compile at runtime. + +#[path = "binaries/rv32_fibonacci_rom.rs"] +mod rv32_fibonacci_rom; + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn test_riscv_fibonacci_compiled_trace_prove_verify() { + // The guest reads n from RAM[0x104], computes fib(n), and writes the result to RAM[0x100]. + let n = 10u32; + let expected = F::from_u64(55); + + let program_base = rv32_fibonacci_rom::RV32_FIBONACCI_ROM_BASE; + let program_bytes: &[u8] = &rv32_fibonacci_rom::RV32_FIBONACCI_ROM; + + let mut run = Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .max_steps(64) + .ram_init_u32(/*addr=*/ 0x104, n) + .shout_auto_minimal() + .output(/*output_addr=*/ 0x100, /*expected_output=*/ expected) + .prove() + .expect("trace-mode prove fibonacci"); + + run.verify().expect("trace-mode verify fibonacci"); + + // Wrong output must fail: prove with wrong expected value should fail at verify. + let wrong_run = Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .max_steps(64) + .ram_init_u32(/*addr=*/ 0x104, n) + .shout_auto_minimal() + .output(/*output_addr=*/ 0x100, /*expected_output=*/ F::from_u64(56)) + .prove(); + + match wrong_run { + Ok(mut run_bad) => { + assert!( + run_bad.verify().is_err(), + "wrong output claim must not verify" + ); + } + Err(_) => { + // Prove itself failed, which is also acceptable. + } + } +} diff --git a/crates/neo-fold/riscv-tests/test_riscv_program_compiled_trace_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_program_compiled_trace_prove_verify.rs new file mode 100644 index 00000000..432d7786 --- /dev/null +++ b/crates/neo-fold/riscv-tests/test_riscv_program_compiled_trace_prove_verify.rs @@ -0,0 +1,49 @@ +//! End-to-end prove+verify for a compiled RV32 smoke-test guest program under the trace-mode runner. +//! +//! The guest is authored in Rust under `riscv-tests/guests/rv32-smoke/` and its ROM bytes are committed in +//! `riscv-tests/binaries/rv32_smoke_rom.rs` so this test doesn't need to cross-compile at runtime. + +#[path = "binaries/rv32_smoke_rom.rs"] +mod rv32_smoke_rom; + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn test_riscv_program_compiled_trace_prove_verify() { + let program_base = rv32_smoke_rom::RV32_SMOKE_ROM_BASE; + let program_bytes: &[u8] = &rv32_smoke_rom::RV32_SMOKE_ROM; + + let mut run = Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .shout_auto_minimal() + .output( + /*output_addr=*/ 0x100, + /*expected_output=*/ F::from_u64(0x100c), + ) + .prove() + .expect("trace-mode prove compiled smoke"); + + run.verify().expect("trace-mode verify compiled smoke"); + + // Wrong output must fail. + let wrong_run = Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .shout_auto_minimal() + .output( + /*output_addr=*/ 0x100, + /*expected_output=*/ F::from_u64(0x100d), + ) + .prove(); + + match wrong_run { + Ok(mut run_bad) => { + assert!( + run_bad.verify().is_err(), + "wrong output claim must not verify" + ); + } + Err(_) => {} + } +} diff --git a/crates/neo-fold/riscv-tests/test_riscv_program_rv32m_trace_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_program_rv32m_trace_prove_verify.rs new file mode 100644 index 00000000..13542e1f --- /dev/null +++ b/crates/neo-fold/riscv-tests/test_riscv_program_rv32m_trace_prove_verify.rs @@ -0,0 +1,215 @@ +//! End-to-end prove+verify for RV32M (M-extension) operations under the trace-mode runner. +//! +//! This validates that M-extension ops (MUL, DIV, etc.) are correctly handled +//! via Shout lookups in trace mode (table IDs 12-19). + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; + +#[test] +#[ignore = "M-ext Shout tables need closed-form MLE or a packed-key proof path for xlen=32"] +fn test_riscv_program_rv32m_trace_prove_verify() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: -6, + }, // x1 = -6 (0xFFFFFFFA) + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 3, + }, // x2 = 3 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, // x3 = (-6)*3 = -18 + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 4, + rs1: 1, + rs2: 2, + }, // x4 = (-6)/3 = -2 + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(32) + .min_trace_len(1) + .prove() + .expect("trace-mode prove with RV32M ops"); + + run.verify().expect("trace-mode verify with RV32M ops"); +} + +#[test] +#[ignore = "M-ext Shout tables need closed-form MLE or a packed-key proof path for xlen=32"] +fn test_riscv_program_rv32m_all_ops_trace_prove_verify() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 7, + }, // x1 = 7 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: 3, + }, // x2 = 3 + // MUL: x3 = 7*3 = 21 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mul, + rd: 3, + rs1: 1, + rs2: 2, + }, + // MULH: x4 = high bits of signed(7)*signed(3) = 0 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulh, + rd: 4, + rs1: 1, + rs2: 2, + }, + // MULHSU: x5 = high bits of signed(7)*unsigned(3) = 0 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhsu, + rd: 5, + rs1: 1, + rs2: 2, + }, + // MULHU: x6 = high bits of unsigned(7)*unsigned(3) = 0 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhu, + rd: 6, + rs1: 1, + rs2: 2, + }, + // DIV: x7 = 7/3 = 2 + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 7, + rs1: 1, + rs2: 2, + }, + // DIVU: x8 = 7/3 = 2 (unsigned) + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 8, + rs1: 1, + rs2: 2, + }, + // REM: x9 = 7%3 = 1 + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 9, + rs1: 1, + rs2: 2, + }, + // REMU: x10 = 7%3 = 1 (unsigned) + RiscvInstruction::RAlu { + op: RiscvOpcode::Remu, + rd: 10, + rs1: 1, + rs2: 2, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(32) + .min_trace_len(1) + .prove() + .expect("trace-mode prove with all RV32M ops"); + + run.verify() + .expect("trace-mode verify with all RV32M ops"); +} + +#[test] +#[ignore = "M-ext Shout tables need closed-form MLE or a packed-key proof path for xlen=32"] +fn test_riscv_program_rv32m_signed_edge_cases_trace_prove_verify() { + let program = vec![ + // x1 = -1 (0xFFFFFFFF) + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: -1, + }, + // x2 = -1 + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 0, + imm: -1, + }, + // MULH(-1, -1): high bits of 1 = 0 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulh, + rd: 3, + rs1: 1, + rs2: 2, + }, + // MULHSU(-1, 0xFFFFFFFF): high bits of signed(-1)*unsigned(0xFFFFFFFF) = -1 + RiscvInstruction::RAlu { + op: RiscvOpcode::Mulhsu, + rd: 4, + rs1: 1, + rs2: 2, + }, + // DIV(-1, -1) = 1 + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 5, + rs1: 1, + rs2: 2, + }, + // x6 = 0 (divisor = 0 case) + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 6, + rs1: 0, + imm: 0, + }, + // DIV(-1, 0) = -1 (RISC-V spec: division by zero returns -1) + RiscvInstruction::RAlu { + op: RiscvOpcode::Div, + rd: 7, + rs1: 1, + rs2: 6, + }, + // DIVU(-1, 0) = 0xFFFFFFFF (RISC-V spec) + RiscvInstruction::RAlu { + op: RiscvOpcode::Divu, + rd: 8, + rs1: 1, + rs2: 6, + }, + // REM(-1, 0) = -1 (RISC-V spec: remainder with divisor 0 = dividend) + RiscvInstruction::RAlu { + op: RiscvOpcode::Rem, + rd: 9, + rs1: 1, + rs2: 6, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_bytes) + .xlen(32) + .min_trace_len(1) + .prove() + .expect("trace-mode prove with signed M-ext edge cases"); + + run.verify() + .expect("trace-mode verify with signed M-ext edge cases"); +} diff --git a/crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_trace_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_trace_prove_verify.rs new file mode 100644 index 00000000..2a3bbc8e --- /dev/null +++ b/crates/neo-fold/riscv-tests/test_riscv_u64_output_compiled_trace_prove_verify.rs @@ -0,0 +1,49 @@ +//! End-to-end prove+verify for a compiled RV32 guest with u64 output under the trace-mode runner. +//! +//! The guest is authored in Rust under `riscv-tests/guests/rv32-u64-output/` and its ROM bytes are committed in +//! `riscv-tests/binaries/rv32_u64_output_rom.rs` so this test doesn't need to cross-compile at runtime. + +#[path = "binaries/rv32_u64_output_rom.rs"] +mod rv32_u64_output_rom; + +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_math::F; +use p3_field::PrimeCharacteristicRing; + +#[test] +fn test_riscv_u64_output_compiled_trace_prove_verify() { + let output = 0x1122_3344_5566_7788u64; + let out_lo = F::from_u64(output as u32 as u64); + let out_hi = F::from_u64((output >> 32) as u32 as u64); + + let program_base = rv32_u64_output_rom::RV32_U64_OUTPUT_ROM_BASE; + let program_bytes: &[u8] = &rv32_u64_output_rom::RV32_U64_OUTPUT_ROM; + + let mut run = Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .shout_auto_minimal() + .output_claim(/*addr=*/ 0x100, /*value=*/ out_lo) + .output_claim(/*addr=*/ 0x104, /*value=*/ out_hi) + .prove() + .expect("trace-mode prove u64 output"); + + run.verify().expect("trace-mode verify u64 output"); + + // Wrong output must fail. + let wrong_run = Rv32TraceWiring::from_rom(program_base, program_bytes) + .xlen(32) + .shout_auto_minimal() + .output_claim(/*addr=*/ 0x100, /*value=*/ out_lo) + .output_claim(/*addr=*/ 0x104, /*value=*/ F::from_u64(0)) + .prove(); + + match wrong_run { + Ok(mut run_bad) => { + assert!( + run_bad.verify().is_err(), + "wrong output claim must not verify" + ); + } + Err(_) => {} + } +} diff --git a/crates/neo-fold/src/memory_sidecar/route_a_time.rs b/crates/neo-fold/src/memory_sidecar/route_a_time.rs index 818b445a..c5c450a1 100644 --- a/crates/neo-fold/src/memory_sidecar/route_a_time.rs +++ b/crates/neo-fold/src/memory_sidecar/route_a_time.rs @@ -25,6 +25,45 @@ pub struct ExtraBatchedTimeClaim { pub label: &'static [u8], } +fn split_extra_claim( + claim: Option, +) -> (Option>, Option<&'static [u8]>, Option) { + match claim { + Some(extra) => (Some(extra.oracle), Some(extra.label), Some(extra.claimed_sum)), + None => (None, None, None), + } +} + +fn append_optional_claim<'a>( + oracle: &'a mut Option>, + label: Option<&'static [u8]>, + claimed_sum: Option, + is_dynamic: bool, + claimed_sums: &mut Vec, + degree_bounds: &mut Vec, + labels: &mut Vec<&'static [u8]>, + claim_is_dynamic: &mut Vec, + claims: &mut Vec>, + missing_label_msg: &'static str, + missing_claimed_sum_msg: &'static str, +) { + if let Some(oracle) = oracle.as_deref_mut() { + let label = label.expect(missing_label_msg); + let claimed_sum = claimed_sum.expect(missing_claimed_sum_msg); + claimed_sums.push(claimed_sum); + degree_bounds.push(oracle.degree_bound()); + labels.push(label); + claim_is_dynamic.push(is_dynamic); + claims.push(BatchedClaim { + oracle, + claimed_sum, + label, + }); + } else { + debug_assert!(label.is_none(), "label present without oracle"); + } +} + pub fn prove_route_a_batched_time( tr: &mut Poseidon2Transcript, step_idx: usize, @@ -109,321 +148,159 @@ pub fn prove_route_a_batched_time( &mut claim_is_dynamic, &mut claims, ); - - let wb_time_degree_bound = wb_time_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut wb_time_label: Option<&'static [u8]> = None; - let mut wb_time_oracle: Option> = wb_time_claim.map(|extra| { - wb_time_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = wb_time_oracle.as_deref_mut() { - // WB is a zero-identity stage: claimed sum is verifier-known and fixed to zero. - let claimed_sum = K::ZERO; - let label = wb_time_label.expect("missing wb_time label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let wp_time_degree_bound = wp_time_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut wp_time_label: Option<&'static [u8]> = None; - let mut wp_time_oracle: Option> = wp_time_claim.map(|extra| { - wp_time_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = wp_time_oracle.as_deref_mut() { - // WP is a zero-identity stage: claimed sum is verifier-known and fixed to zero. - let claimed_sum = K::ZERO; - let label = wp_time_label.expect("missing wp_time label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let decode_decode_fields_degree_bound = decode_decode_fields_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut decode_decode_fields_label: Option<&'static [u8]> = None; - let mut decode_decode_fields_oracle: Option> = decode_decode_fields_claim.map(|extra| { - decode_decode_fields_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = decode_decode_fields_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = decode_decode_fields_label.expect("missing decode_fields label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let decode_decode_immediates_degree_bound = decode_decode_immediates_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut decode_decode_immediates_label: Option<&'static [u8]> = None; - let mut decode_decode_immediates_oracle: Option> = - decode_decode_immediates_claim.map(|extra| { - decode_decode_immediates_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = decode_decode_immediates_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = decode_decode_immediates_label.expect("missing decode_immediates label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let width_bitness_degree_bound = width_bitness_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut width_bitness_label: Option<&'static [u8]> = None; - let mut width_bitness_oracle: Option> = width_bitness_claim.map(|extra| { - width_bitness_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = width_bitness_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = width_bitness_label.expect("missing width_bitness label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let width_quiescence_degree_bound = width_quiescence_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut width_quiescence_label: Option<&'static [u8]> = None; - let mut width_quiescence_oracle: Option> = width_quiescence_claim.map(|extra| { - width_quiescence_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = width_quiescence_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = width_quiescence_label.expect("missing width_quiescence label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let width_selector_linkage_degree_bound = width_selector_linkage_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut width_selector_linkage_label: Option<&'static [u8]> = None; - let mut width_selector_linkage_oracle: Option> = width_selector_linkage_claim.map(|extra| { - width_selector_linkage_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = width_selector_linkage_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = width_selector_linkage_label.expect("missing width_selector_linkage label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); + macro_rules! append_zero_optional_claim { + ($claim_opt:ident, $degree_bound:ident, $oracle:ident, $label:ident, $missing_label_msg:literal, $missing_claimed_sum_msg:literal) => { + let $degree_bound = $claim_opt.as_ref().map(|extra| extra.oracle.degree_bound()); + let (mut $oracle, $label, _claimed_sum) = split_extra_claim($claim_opt); + append_optional_claim( + &mut $oracle, + $label, + Some(K::ZERO), + false, + &mut claimed_sums, + &mut degree_bounds, + &mut labels, + &mut claim_is_dynamic, + &mut claims, + $missing_label_msg, + $missing_claimed_sum_msg, + ); + }; } - let width_load_semantics_degree_bound = width_load_semantics_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut width_load_semantics_label: Option<&'static [u8]> = None; - let mut width_load_semantics_oracle: Option> = width_load_semantics_claim.map(|extra| { - width_load_semantics_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = width_load_semantics_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = width_load_semantics_label.expect("missing width_load_semantics label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let width_store_semantics_degree_bound = width_store_semantics_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut width_store_semantics_label: Option<&'static [u8]> = None; - let mut width_store_semantics_oracle: Option> = width_store_semantics_claim.map(|extra| { - width_store_semantics_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = width_store_semantics_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = width_store_semantics_label.expect("missing width_store_semantics label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let control_next_pc_linear_degree_bound = control_next_pc_linear_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut control_next_pc_linear_label: Option<&'static [u8]> = None; - let mut control_next_pc_linear_oracle: Option> = control_next_pc_linear_claim.map(|extra| { - control_next_pc_linear_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = control_next_pc_linear_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = control_next_pc_linear_label.expect("missing control_next_pc_linear label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let control_next_pc_control_degree_bound = control_next_pc_control_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut control_next_pc_control_label: Option<&'static [u8]> = None; - let mut control_next_pc_control_oracle: Option> = control_next_pc_control_claim.map(|extra| { - control_next_pc_control_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = control_next_pc_control_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = control_next_pc_control_label.expect("missing control_next_pc_control label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let control_branch_semantics_degree_bound = control_branch_semantics_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut control_branch_semantics_label: Option<&'static [u8]> = None; - let mut control_branch_semantics_oracle: Option> = - control_branch_semantics_claim.map(|extra| { - control_branch_semantics_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = control_branch_semantics_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = control_branch_semantics_label.expect("missing control_branch_semantics label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); + macro_rules! append_dynamic_optional_claim { + ($claim_opt:ident, $degree_bound:ident, $oracle:ident, $label:ident, $claimed_sum:ident, $missing_label_msg:literal, $missing_claimed_sum_msg:literal) => { + let $degree_bound = $claim_opt.as_ref().map(|extra| extra.oracle.degree_bound()); + let (mut $oracle, $label, $claimed_sum) = split_extra_claim($claim_opt); + append_optional_claim( + &mut $oracle, + $label, + $claimed_sum, + true, + &mut claimed_sums, + &mut degree_bounds, + &mut labels, + &mut claim_is_dynamic, + &mut claims, + $missing_label_msg, + $missing_claimed_sum_msg, + ); + }; } - let control_control_writeback_degree_bound = control_control_writeback_claim - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut control_control_writeback_label: Option<&'static [u8]> = None; - let mut control_control_writeback_oracle: Option> = - control_control_writeback_claim.map(|extra| { - control_control_writeback_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = control_control_writeback_oracle.as_deref_mut() { - let claimed_sum = K::ZERO; - let label = control_control_writeback_label.expect("missing control_writeback label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(false); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } - - let ob_inc_total_degree_bound = ob_inc_total - .as_ref() - .map(|extra| extra.oracle.degree_bound()); - let mut ob_inc_total_claimed_sum: Option = None; - let mut ob_inc_total_label: Option<&'static [u8]> = None; - let mut ob_inc_total_oracle: Option> = ob_inc_total.map(|extra| { - ob_inc_total_claimed_sum = Some(extra.claimed_sum); - ob_inc_total_label = Some(extra.label); - extra.oracle - }); - if let Some(oracle) = ob_inc_total_oracle.as_deref_mut() { - let claimed_sum = ob_inc_total_claimed_sum.expect("missing ob_inc_total claimed_sum"); - let label = ob_inc_total_label.expect("missing ob_inc_total label"); - claimed_sums.push(claimed_sum); - degree_bounds.push(oracle.degree_bound()); - labels.push(label); - claim_is_dynamic.push(true); - claims.push(BatchedClaim { - oracle, - claimed_sum, - label, - }); - } + append_zero_optional_claim!( + wb_time_claim, + wb_time_degree_bound, + wb_time_oracle, + wb_time_label, + "missing wb_time label", + "missing wb_time claimed_sum" + ); + append_zero_optional_claim!( + wp_time_claim, + wp_time_degree_bound, + wp_time_oracle, + wp_time_label, + "missing wp_time label", + "missing wp_time claimed_sum" + ); + append_zero_optional_claim!( + decode_decode_fields_claim, + decode_decode_fields_degree_bound, + decode_decode_fields_oracle, + decode_decode_fields_label, + "missing decode_fields label", + "missing decode_fields claimed_sum" + ); + append_zero_optional_claim!( + decode_decode_immediates_claim, + decode_decode_immediates_degree_bound, + decode_decode_immediates_oracle, + decode_decode_immediates_label, + "missing decode_immediates label", + "missing decode_immediates claimed_sum" + ); + append_zero_optional_claim!( + width_bitness_claim, + width_bitness_degree_bound, + width_bitness_oracle, + width_bitness_label, + "missing width_bitness label", + "missing width_bitness claimed_sum" + ); + append_zero_optional_claim!( + width_quiescence_claim, + width_quiescence_degree_bound, + width_quiescence_oracle, + width_quiescence_label, + "missing width_quiescence label", + "missing width_quiescence claimed_sum" + ); + append_zero_optional_claim!( + width_selector_linkage_claim, + width_selector_linkage_degree_bound, + width_selector_linkage_oracle, + width_selector_linkage_label, + "missing width_selector_linkage label", + "missing width_selector_linkage claimed_sum" + ); + append_zero_optional_claim!( + width_load_semantics_claim, + width_load_semantics_degree_bound, + width_load_semantics_oracle, + width_load_semantics_label, + "missing width_load_semantics label", + "missing width_load_semantics claimed_sum" + ); + append_zero_optional_claim!( + width_store_semantics_claim, + width_store_semantics_degree_bound, + width_store_semantics_oracle, + width_store_semantics_label, + "missing width_store_semantics label", + "missing width_store_semantics claimed_sum" + ); + append_zero_optional_claim!( + control_next_pc_linear_claim, + control_next_pc_linear_degree_bound, + control_next_pc_linear_oracle, + control_next_pc_linear_label, + "missing control_next_pc_linear label", + "missing control_next_pc_linear claimed_sum" + ); + append_zero_optional_claim!( + control_next_pc_control_claim, + control_next_pc_control_degree_bound, + control_next_pc_control_oracle, + control_next_pc_control_label, + "missing control_next_pc_control label", + "missing control_next_pc_control claimed_sum" + ); + append_zero_optional_claim!( + control_branch_semantics_claim, + control_branch_semantics_degree_bound, + control_branch_semantics_oracle, + control_branch_semantics_label, + "missing control_branch_semantics label", + "missing control_branch_semantics claimed_sum" + ); + append_zero_optional_claim!( + control_control_writeback_claim, + control_control_writeback_degree_bound, + control_control_writeback_oracle, + control_control_writeback_label, + "missing control_writeback label", + "missing control_writeback claimed_sum" + ); + append_dynamic_optional_claim!( + ob_inc_total, + ob_inc_total_degree_bound, + ob_inc_total_oracle, + ob_inc_total_label, + ob_inc_total_claimed_sum, + "missing ob_inc_total label", + "missing ob_inc_total claimed_sum" + ); let metas = RouteATimeClaimPlan::time_claim_metas_for_instances( step.lut_instances.iter().map(|(inst, _)| inst), diff --git a/crates/neo-fold/tests/suites/integration/riscv_trace_ccs_diverse_programs.rs b/crates/neo-fold/tests/suites/integration/riscv_trace_ccs_diverse_programs.rs new file mode 100644 index 00000000..9d36ebda --- /dev/null +++ b/crates/neo-fold/tests/suites/integration/riscv_trace_ccs_diverse_programs.rs @@ -0,0 +1,257 @@ +//! CCS-level unit tests for the trace wiring CCS across diverse instruction types. +//! +//! These catch constraint design bugs by checking A*z . B*z = C*z on the raw CCS +//! without folding/proving, for every major instruction category. + +use neo_ccs::relations::check_ccs_rowwise_zero; +use neo_memory::riscv::ccs::{build_rv32_trace_wiring_ccs, rv32_trace_ccs_witness_from_exec_table, Rv32TraceCcsLayout}; +use neo_memory::riscv::exec_table::Rv32ExecTable; +use neo_memory::riscv::lookups::{ + decode_program, encode_program, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, RiscvOpcode, + RiscvShoutTables, PROG_ID, +}; +use neo_vm_trace::trace_program; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +fn trace_and_check_ccs(program: Vec, label: &str) { + let program_bytes = encode_program(&program); + let decoded = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(32); + cpu.load_program(0, decoded); + let twist = RiscvMemory::with_program_in_twist(32, PROG_ID, 0, &program_bytes); + let shout = RiscvShoutTables::new(32); + let trace = trace_program(cpu, twist, shout, 64).expect("trace_program"); + assert!(trace.did_halt(), "{label}: expected Halt"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, 4).expect("from_trace_padded_pow2"); + exec.validate_cycle_chain().unwrap_or_else(|e| panic!("{label}: {e}")); + exec.validate_pc_chain().unwrap_or_else(|e| panic!("{label}: {e}")); + exec.validate_halted_tail().unwrap_or_else(|e| panic!("{label}: {e}")); + exec.validate_inactive_rows_are_empty().unwrap_or_else(|e| panic!("{label}: {e}")); + + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("trace CCS layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("trace CCS witness"); + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + check_ccs_rowwise_zero(&ccs, &x, &w).unwrap_or_else(|e| panic!("{label}: CCS not satisfied: {e:?}")); +} + +fn trace_and_get_witness(program: Vec) -> (Rv32TraceCcsLayout, Vec, Vec) { + let program_bytes = encode_program(&program); + let decoded = decode_program(&program_bytes).expect("decode_program"); + let mut cpu = RiscvCpu::new(32); + cpu.load_program(0, decoded); + let twist = RiscvMemory::with_program_in_twist(32, PROG_ID, 0, &program_bytes); + let shout = RiscvShoutTables::new(32); + let trace = trace_program(cpu, twist, shout, 64).expect("trace_program"); + + let exec = Rv32ExecTable::from_trace_padded_pow2(&trace, 4).expect("from_trace_padded_pow2"); + let layout = Rv32TraceCcsLayout::new(exec.rows.len()).expect("layout"); + let (x, w) = rv32_trace_ccs_witness_from_exec_table(&layout, &exec).expect("witness"); + (layout, x, w) +} + +// ── Happy-path CCS satisfaction for diverse instruction types ── + +#[test] +fn trace_ccs_happy_rv32i_alu_ops() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 5 }, + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 2, rs1: 0, imm: 3 }, + RiscvInstruction::RAlu { op: RiscvOpcode::Add, rd: 3, rs1: 1, rs2: 2 }, + RiscvInstruction::RAlu { op: RiscvOpcode::Sub, rd: 4, rs1: 1, rs2: 2 }, + RiscvInstruction::RAlu { op: RiscvOpcode::And, rd: 5, rs1: 1, rs2: 2 }, + RiscvInstruction::RAlu { op: RiscvOpcode::Or, rd: 6, rs1: 1, rs2: 2 }, + RiscvInstruction::RAlu { op: RiscvOpcode::Xor, rd: 7, rs1: 1, rs2: 2 }, + RiscvInstruction::RAlu { op: RiscvOpcode::Slt, rd: 8, rs1: 2, rs2: 1 }, + RiscvInstruction::RAlu { op: RiscvOpcode::Sltu, rd: 9, rs1: 2, rs2: 1 }, + RiscvInstruction::Halt, + ], + "rv32i_alu_ops", + ); +} + +#[test] +fn trace_ccs_happy_shifts() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 0x0F }, + RiscvInstruction::IAlu { op: RiscvOpcode::Sll, rd: 2, rs1: 1, imm: 4 }, + RiscvInstruction::IAlu { op: RiscvOpcode::Srl, rd: 3, rs1: 2, imm: 2 }, + RiscvInstruction::IAlu { op: RiscvOpcode::Sra, rd: 4, rs1: 2, imm: 2 }, + RiscvInstruction::Halt, + ], + "shifts", + ); +} + +#[test] +fn trace_ccs_happy_load_store_word() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 42 }, + RiscvInstruction::Store { op: RiscvMemOp::Sw, rs1: 0, rs2: 1, imm: 0x100 }, + RiscvInstruction::Load { op: RiscvMemOp::Lw, rd: 2, rs1: 0, imm: 0x100 }, + RiscvInstruction::Halt, + ], + "load_store_word", + ); +} + +#[test] +fn trace_ccs_happy_byte_half_loads() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 0x1FF }, + RiscvInstruction::Store { op: RiscvMemOp::Sw, rs1: 0, rs2: 1, imm: 0x100 }, + RiscvInstruction::Load { op: RiscvMemOp::Lb, rd: 2, rs1: 0, imm: 0x100 }, + RiscvInstruction::Load { op: RiscvMemOp::Lbu, rd: 3, rs1: 0, imm: 0x100 }, + RiscvInstruction::Load { op: RiscvMemOp::Lh, rd: 4, rs1: 0, imm: 0x100 }, + RiscvInstruction::Load { op: RiscvMemOp::Lhu, rd: 5, rs1: 0, imm: 0x100 }, + RiscvInstruction::Halt, + ], + "byte_half_loads", + ); +} + +#[test] +fn trace_ccs_happy_sub_word_stores() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 0xAB }, + RiscvInstruction::Store { op: RiscvMemOp::Sb, rs1: 0, rs2: 1, imm: 0x100 }, + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 2, rs1: 0, imm: 0x1234 }, + RiscvInstruction::Store { op: RiscvMemOp::Sh, rs1: 0, rs2: 2, imm: 0x104 }, + RiscvInstruction::Halt, + ], + "sub_word_stores", + ); +} + +#[test] +fn trace_ccs_happy_branches_beq_bne() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 5 }, + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 2, rs1: 0, imm: 5 }, + RiscvInstruction::Branch { cond: BranchCondition::Eq, rs1: 1, rs2: 2, imm: 8 }, + RiscvInstruction::Halt, + RiscvInstruction::Branch { cond: BranchCondition::Ne, rs1: 1, rs2: 0, imm: 8 }, + RiscvInstruction::Halt, + RiscvInstruction::Halt, + ], + "branches_beq_bne", + ); +} + +#[test] +fn trace_ccs_happy_branches_blt_bge() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 3 }, + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 2, rs1: 0, imm: 5 }, + RiscvInstruction::Branch { cond: BranchCondition::Lt, rs1: 1, rs2: 2, imm: 8 }, + RiscvInstruction::Halt, + RiscvInstruction::Branch { cond: BranchCondition::Geu, rs1: 2, rs2: 1, imm: 8 }, + RiscvInstruction::Halt, + RiscvInstruction::Halt, + ], + "branches_blt_bge", + ); +} + +#[test] +fn trace_ccs_happy_jal_jalr() { + trace_and_check_ccs( + vec![ + RiscvInstruction::Jal { rd: 1, imm: 8 }, + RiscvInstruction::Halt, + RiscvInstruction::Jalr { rd: 2, rs1: 1, imm: 0 }, + RiscvInstruction::Halt, + ], + "jal_jalr", + ); +} + +#[test] +fn trace_ccs_happy_lui_auipc() { + trace_and_check_ccs( + vec![ + RiscvInstruction::Lui { rd: 1, imm: 0x12 }, + RiscvInstruction::Auipc { rd: 2, imm: 0 }, + RiscvInstruction::Halt, + ], + "lui_auipc", + ); +} + +#[test] +fn trace_ccs_happy_fence() { + trace_and_check_ccs( + vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 1 }, + RiscvInstruction::Fence { pred: 0xF, succ: 0xF }, + RiscvInstruction::Halt, + ], + "fence", + ); +} + +// ── CCS-level tamper tests (detect constraint violations without folding) ── + +#[test] +fn trace_ccs_rejects_tampered_pc_after() { + let program = vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 1 }, + RiscvInstruction::Halt, + ]; + let (layout, x, mut w) = trace_and_get_witness(program); + + let idx = layout.trace.pc_after * layout.t + 0; + w[idx] += F::ONE; + + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let res = check_ccs_rowwise_zero(&ccs, &x, &w); + assert!(res.is_err(), "tampered pc_after should violate CCS"); +} + +// NOTE: shout_val is NOT constrained by the trace wiring CCS alone. +// It is bound through the Shout bus (Route-A claim). The full-pipeline +// tamper test lives in trace_bus_binding_redteam::trace_cpu_vs_bus_shout_val_mismatch_must_fail. + +#[test] +fn trace_ccs_rejects_tampered_halted_flag() { + let program = vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 1 }, + RiscvInstruction::Halt, + ]; + let (layout, x, mut w) = trace_and_get_witness(program); + + let idx = layout.trace.halted * layout.t + 0; + w[idx] += F::ONE; + + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let res = check_ccs_rowwise_zero(&ccs, &x, &w); + assert!(res.is_err(), "tampered halted flag should violate CCS"); +} + +// NOTE: rd_val is NOT constrained by the trace wiring CCS alone. +// It is bound through the Twist bus (memory sidecar). The full-pipeline +// tamper test for memory values lives in trace_bus_binding_redteam. + +#[test] +fn trace_ccs_rejects_tampered_cycle() { + let program = vec![ + RiscvInstruction::IAlu { op: RiscvOpcode::Add, rd: 1, rs1: 0, imm: 1 }, + RiscvInstruction::Halt, + ]; + let (layout, x, mut w) = trace_and_get_witness(program); + + let idx = layout.trace.cycle * layout.t + 0; + w[idx] += F::ONE; + + let ccs = build_rv32_trace_wiring_ccs(&layout).expect("trace CCS"); + let res = check_ccs_rowwise_zero(&ccs, &x, &w); + assert!(res.is_err(), "tampered cycle counter should violate CCS"); +} diff --git a/crates/neo-fold/tests/suites/shared_bus/trace_bus_binding_redteam.rs b/crates/neo-fold/tests/suites/shared_bus/trace_bus_binding_redteam.rs new file mode 100644 index 00000000..cf035ce3 --- /dev/null +++ b/crates/neo-fold/tests/suites/shared_bus/trace_bus_binding_redteam.rs @@ -0,0 +1,113 @@ +use neo_ajtai::AjtaiSModule; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::session::FoldingSession; +use neo_math::K; +use neo_memory::ajtai::encode_vector_balanced_to_mat; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as F; + +type StepWit = neo_memory::witness::StepWitnessBundle; + +fn prove_run(program: Vec, max_steps: usize) -> Rv32TraceWiringRun { + let program_bytes = encode_program(&program); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(max_steps) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + run +} + +fn step_bundle_recommit_after_private_tamper( + params: &neo_params::NeoParams, + committer: &AjtaiSModule, + step: &mut StepWit, + idx_to_tamper: usize, + delta: F, +) { + let (ref mut inst, ref mut wit) = step.mcs; + let m_in = inst.m_in; + assert!( + idx_to_tamper >= m_in, + "expected idx_to_tamper to be in the private witness region (idx={idx_to_tamper}, m_in={m_in})" + ); + + let mut z = Vec::with_capacity(m_in + wit.w.len()); + z.extend_from_slice(&inst.x); + z.extend_from_slice(&wit.w); + assert!(idx_to_tamper < z.len(), "idx_to_tamper out of range"); + + z[idx_to_tamper] += delta; + wit.w = z[m_in..].to_vec(); + wit.Z = encode_vector_balanced_to_mat(params, &z); + inst.c = neo_ccs::traits::SModuleHomomorphism::commit(committer, &wit.Z); +} + +fn prove_main_shard_proof_or_verify_fails(run: &Rv32TraceWiringRun, steps_bad: Vec) { + let mut sess = FoldingSession::new(FoldingMode::Optimized, run.params().clone(), run.committer().clone()); + sess.set_step_linking(run.step_linking_config()); + sess.add_step_bundles(steps_bad); + + let Ok(proof_bad) = sess.fold_and_prove(run.ccs()) else { + return; + }; + let res = sess.verify_collected(run.ccs(), &proof_bad); + assert!( + matches!(res, Err(_) | Ok(false)), + "malicious main proof unexpectedly verified" + ); +} + +#[test] +fn trace_cpu_vs_bus_twist_rv_mismatch_must_fail() { + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let run = prove_run(program, 2); + + let layout = run.layout(); + let t = layout.t; + let trace = &layout.trace; + + let idx_ram_rv = layout.trace_base + trace.ram_rv * t + 0; + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + step_bundle_recommit_after_private_tamper(run.params(), run.committer(), &mut steps_bad[0], idx_ram_rv, F::ONE); + + prove_main_shard_proof_or_verify_fails(&run, steps_bad); +} + +#[test] +fn trace_cpu_vs_bus_shout_val_mismatch_must_fail() { + let run = prove_run( + vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ], + 2, + ); + + let layout = run.layout(); + let t = layout.t; + let trace = &layout.trace; + + let idx_shout_val = layout.trace_base + trace.shout_val * t + 0; + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + step_bundle_recommit_after_private_tamper(run.params(), run.committer(), &mut steps_bad[0], idx_shout_val, F::ONE); + + prove_main_shard_proof_or_verify_fails(&run, steps_bad); +} diff --git a/crates/neo-fold/tests/suites/shared_bus/trace_main_proof_redteam.rs b/crates/neo-fold/tests/suites/shared_bus/trace_main_proof_redteam.rs new file mode 100644 index 00000000..de9a674a --- /dev/null +++ b/crates/neo-fold/tests/suites/shared_bus/trace_main_proof_redteam.rs @@ -0,0 +1,137 @@ +use neo_ajtai::AjtaiSModule; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::session::FoldingSession; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode, PROG_ID, REG_ID}; +use neo_memory::MemInit; +use p3_goldilocks::Goldilocks as F; + +type StepWit = neo_memory::witness::StepWitnessBundle; + +fn addi_halt_program_bytes(imm: i32) -> Vec { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm, + }, + RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +fn mem_idx(steps: &[StepWit], mem_id: u32) -> usize { + steps[0] + .mem_instances + .iter() + .position(|(inst, _)| inst.mem_id == mem_id) + .unwrap_or_else(|| panic!("missing mem_id={mem_id} in step mem_instances")) +} + +fn verifier_only_session_for_steps(run: &Rv32TraceWiringRun, steps: Vec) -> FoldingSession { + let mut sess = FoldingSession::new(FoldingMode::Optimized, run.params().clone(), run.committer().clone()); + sess.set_step_linking(run.step_linking_config()); + sess.add_step_bundles(steps); + sess +} + +#[test] +fn trace_main_proof_truncated_steps_must_fail() { + let program_bytes = addi_halt_program_bytes(1); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(2) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + // Baseline: full step set verifies. + let steps_ok: Vec = run.steps_witness().to_vec(); + let sess_ok = verifier_only_session_for_steps(&run, steps_ok); + assert_eq!( + sess_ok.verify_collected(run.ccs(), run.proof()).expect("main proof verify"), + true + ); + + // Trace mode: entire trace = single step bundle, so the meaningful + // truncation attack is providing ZERO steps. + let steps_bad: Vec = Vec::new(); + let sess_bad = verifier_only_session_for_steps(&run, steps_bad); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); + assert!(matches!(res, Err(_) | Ok(false)), "zero steps must not verify"); +} + +#[test] +fn trace_main_proof_tamper_prog_init_must_fail() { + let program_bytes = addi_halt_program_bytes(1); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(2) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + let prog_idx = mem_idx(&steps_bad, PROG_ID.0); + steps_bad[0].mem_instances[prog_idx].0.init = MemInit::Zero; + + let sess_bad = verifier_only_session_for_steps(&run, steps_bad); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); + assert!( + matches!(res, Err(_) | Ok(false)), + "tampering PROG Twist init in public input must fail verification" + ); +} + +#[test] +fn trace_main_proof_tamper_reg_init_must_fail() { + let program_bytes = addi_halt_program_bytes(1); + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(2) + .reg_init_u32(2, 7) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + let reg_idx = mem_idx(&steps_bad, REG_ID.0); + steps_bad[0].mem_instances[reg_idx].0.init = MemInit::Zero; + + let sess_bad = verifier_only_session_for_steps(&run, steps_bad); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); + assert!( + matches!(res, Err(_) | Ok(false)), + "tampering REG Twist init in public input must fail verification" + ); +} + +// NOTE: Step reordering is not applicable in trace mode because the entire +// trace is folded as a single step bundle. The intra-step state chain +// (PC chain, cycle chain) is enforced by the trace wiring CCS constraints, +// verified in riscv_trace_ccs_diverse_programs::trace_ccs_rejects_tampered_pc_after +// and trace_ccs_rejects_tampered_cycle. + +#[test] +fn trace_main_proof_splicing_across_runs_must_fail() { + let program_bytes_a = addi_halt_program_bytes(1); + let mut run_a = Rv32TraceWiring::from_rom(0, &program_bytes_a) + .max_steps(2) + .prove() + .expect("prove A"); + run_a.verify().expect("baseline verify A"); + + let program_bytes_b = addi_halt_program_bytes(2); + let mut run_b = Rv32TraceWiring::from_rom(0, &program_bytes_b) + .max_steps(2) + .prove() + .expect("prove B"); + run_b.verify().expect("baseline verify B"); + + let steps_bad: Vec = run_b.steps_witness().to_vec(); + let sess_bad = verifier_only_session_for_steps(&run_a, steps_bad); + let res = sess_bad.verify_collected(run_a.ccs(), run_a.proof()); + assert!( + matches!(res, Err(_) | Ok(false)), + "splicing main proof across runs must not verify" + ); +} diff --git a/crates/neo-fold/tests/suites/shared_bus/trace_twist_shout_redteam.rs b/crates/neo-fold/tests/suites/shared_bus/trace_twist_shout_redteam.rs new file mode 100644 index 00000000..3c66230e --- /dev/null +++ b/crates/neo-fold/tests/suites/shared_bus/trace_twist_shout_redteam.rs @@ -0,0 +1,172 @@ +use neo_ajtai::AjtaiSModule; +use neo_fold::pi_ccs::FoldingMode; +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_fold::session::FoldingSession; +use neo_math::K; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvMemOp, RiscvOpcode, RAM_ID}; +use neo_memory::witness::LutTableSpec; +use neo_memory::MemInit; +use p3_goldilocks::Goldilocks as F; + +type StepWit = neo_memory::witness::StepWitnessBundle; + +fn verifier_only_session_for_steps(run: &Rv32TraceWiringRun, steps: Vec) -> FoldingSession { + let mut sess = FoldingSession::new(FoldingMode::Optimized, run.params().clone(), run.committer().clone()); + sess.set_step_linking(run.step_linking_config()); + sess.add_step_bundles(steps); + sess +} + +#[test] +fn trace_twist_instances_reordered_must_fail() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(2) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + assert!( + steps_bad[0].mem_instances.len() >= 2, + "expected at least 2 Twist instances" + ); + steps_bad[0].mem_instances.swap(0, 1); + + let sess_bad = verifier_only_session_for_steps(&run, steps_bad); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); + assert!( + matches!(res, Err(_) | Ok(false)), + "reordering Twist instances must not verify" + ); +} + +#[test] +fn trace_shout_table_spec_tamper_must_fail() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(2) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + assert!( + !steps_bad[0].lut_instances.is_empty(), + "expected at least 1 Shout instance" + ); + let lut_inst = &mut steps_bad[0].lut_instances[0].0; + assert!( + matches!(&lut_inst.table_spec, Some(LutTableSpec::RiscvOpcode { .. })), + "expected a virtual RISC-V opcode table (table_spec=Some)" + ); + lut_inst.table_spec = Some(LutTableSpec::RiscvOpcode { + opcode: RiscvOpcode::Xor, + xlen: 32, + }); + + let sess_bad = verifier_only_session_for_steps(&run, steps_bad); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); + assert!( + matches!(res, Err(_) | Ok(false)), + "tampering Shout table_spec must not verify" + ); +} + +#[test] +fn trace_shout_instances_reordered_must_fail() { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: 1, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Or, + rd: 2, + rs1: 1, + imm: 3, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(4) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + assert!( + steps_bad[0].lut_instances.len() >= 2, + "expected at least 2 Shout instances for ADDI+ORI program" + ); + steps_bad[0].lut_instances.swap(0, 1); + + let sess_bad = verifier_only_session_for_steps(&run, steps_bad); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); + assert!( + matches!(res, Err(_) | Ok(false)), + "reordering Shout instances must not verify" + ); +} + +#[test] +fn trace_ram_init_statement_tamper_must_fail() { + let program = vec![ + RiscvInstruction::Load { + op: RiscvMemOp::Lw, + rd: 1, + rs1: 0, + imm: 0, + }, + RiscvInstruction::Halt, + ]; + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .max_steps(2) + .ram_init_u32(0, 7) + .prove() + .expect("prove"); + run.verify().expect("baseline verify"); + + let ram_idx = run + .steps_witness()[0] + .mem_instances + .iter() + .position(|(inst, _)| inst.mem_id == RAM_ID.0) + .expect("missing RAM Twist instance"); + + let mut steps_bad: Vec = run.steps_witness().to_vec(); + steps_bad[0].mem_instances[ram_idx].0.init = MemInit::Zero; + + let sess_bad = verifier_only_session_for_steps(&run, steps_bad); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); + assert!( + matches!(res, Err(_) | Ok(false)), + "tampering RAM Twist init in public input must fail verification" + ); +} diff --git a/crates/neo-memory/src/builder.rs b/crates/neo-memory/src/builder.rs index 4898df9d..7ebcc9e1 100644 --- a/crates/neo-memory/src/builder.rs +++ b/crates/neo-memory/src/builder.rs @@ -80,6 +80,19 @@ fn ell_from_pow2_n_side(n_side: usize) -> Result { Ok(n_side.trailing_zeros() as usize) } +fn validate_chunk_size(chunk_size: usize) -> Result<(), ShardBuildError> { + if chunk_size == 0 { + return Err(ShardBuildError::InvalidChunkSize("chunk_size must be >= 1".into())); + } + Ok(()) +} + +fn bundles_only( + out: Result<(Vec>, ShardWitnessAux), ShardBuildError>, +) -> Result>, ShardBuildError> { + out.map(|(bundles, _aux)| bundles) +} + /// Build shard witness bundles for **shared CPU bus** mode. /// /// In this mode Twist/Shout access-row columns are expected to live in the CPU witness `z` @@ -108,7 +121,7 @@ where Sh: neo_vm_trace::Shout, A: CpuArithmetization, { - let (bundles, _aux) = build_shard_witness_shared_cpu_bus_with_aux( + bundles_only(build_shard_witness_shared_cpu_bus_with_aux( vm, twist, shout, @@ -120,8 +133,7 @@ where lut_lanes, initial_mem, cpu_arith, - )?; - Ok(bundles) + )) } /// Build shard witness bundles for **shared CPU bus** mode from an already-executed VM trace. @@ -142,7 +154,7 @@ pub fn build_shard_witness_shared_cpu_bus_from_trace( where A: CpuArithmetization, { - let (bundles, _aux) = build_shard_witness_shared_cpu_bus_from_trace_with_aux( + bundles_only(build_shard_witness_shared_cpu_bus_from_trace_with_aux( trace, max_steps, chunk_size, @@ -152,8 +164,7 @@ where lut_lanes, initial_mem, cpu_arith, - )?; - Ok(bundles) + )) } /// Like `build_shard_witness_shared_cpu_bus_from_trace`, but also returns auxiliary outputs useful @@ -172,9 +183,7 @@ pub fn build_shard_witness_shared_cpu_bus_from_trace_with_aux( where A: CpuArithmetization, { - if chunk_size == 0 { - return Err(ShardBuildError::InvalidChunkSize("chunk_size must be >= 1".into())); - } + validate_chunk_size(chunk_size)?; if trace.steps.len() > max_steps { return Err(ShardBuildError::InvalidChunkSize(format!( "trace length {} exceeds max_steps {}", @@ -443,9 +452,7 @@ where Sh: neo_vm_trace::Shout, A: CpuArithmetization, { - if chunk_size == 0 { - return Err(ShardBuildError::InvalidChunkSize("chunk_size must be >= 1".into())); - } + validate_chunk_size(chunk_size)?; // 1) Run VM and collect the executed trace for this shard (up to `max_steps`). // diff --git a/crates/neo-memory/src/cpu/constraints.rs b/crates/neo-memory/src/cpu/constraints.rs index 73ce1c3f..ec6f7b75 100644 --- a/crates/neo-memory/src/cpu/constraints.rs +++ b/crates/neo-memory/src/cpu/constraints.rs @@ -323,6 +323,15 @@ pub struct CpuConstraintBuilder { constraints: Vec>, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ShoutPaddingMode { + Full, + WithoutSelectorBitness, + ValueOnly, + ValuePaddingOnly, + AddrBitBitnessOnly, +} + impl CpuConstraintBuilder { /// Create a new constraint builder. /// @@ -600,86 +609,71 @@ impl CpuConstraintBuilder { /// Add Shout selector/bitness/padding constraints only (no CPU linkage). pub fn add_shout_instance_padding(&mut self, layout: &BusLayout, shout: &ShoutCols) { - for j in 0..layout.chunk_size { - let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); - let bus_val = layout.bus_cell(shout.primary_val(), j); - - // Ensure bus selector is boolean so gated-bit constraints imply true {0,1} bitness. - self.add_boolean_constraint(CpuConstraintLabel::ShoutHasLookupBoolean, bus_has_lookup); - - // Padding: (1 - has_lookup) * val = 0 - self.constraints.push(CpuConstraint::new_zero_negated( - CpuConstraintLabel::LookupValueZeroPadding, - bus_has_lookup, - bus_val, - )); - - // Lookup key bits: - // - Bitness: bit is 0 when inactive, boolean when active - for col_id in shout.addr_bits.clone() { - let bit = layout.bus_cell(col_id, j); - self.add_gated_bit_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit, bus_has_lookup); - } - } + self.add_shout_instance_padding_mode(layout, shout, ShoutPaddingMode::Full); } /// Add Shout addr-bit/value padding constraints without selector booleanity. pub fn add_shout_instance_padding_without_selector_bitness(&mut self, layout: &BusLayout, shout: &ShoutCols) { - for j in 0..layout.chunk_size { - let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); - let bus_val = layout.bus_cell(shout.primary_val(), j); - - // Padding: (1 - has_lookup) * val = 0 - self.constraints.push(CpuConstraint::new_zero_negated( - CpuConstraintLabel::LookupValueZeroPadding, - bus_has_lookup, - bus_val, - )); - - // Lookup key bits: - // - Bitness: bit is 0 when inactive, boolean when active - for col_id in shout.addr_bits.clone() { - let bit = layout.bus_cell(col_id, j); - self.add_gated_bit_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit, bus_has_lookup); - } - } + self.add_shout_instance_padding_mode(layout, shout, ShoutPaddingMode::WithoutSelectorBitness); } /// Add Shout selector/value padding only (no addr-bit constraints). pub fn add_shout_instance_padding_value_only(&mut self, layout: &BusLayout, shout: &ShoutCols) { - for j in 0..layout.chunk_size { - let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); - let bus_val = layout.bus_cell(shout.primary_val(), j); - - self.add_boolean_constraint(CpuConstraintLabel::ShoutHasLookupBoolean, bus_has_lookup); - self.constraints.push(CpuConstraint::new_zero_negated( - CpuConstraintLabel::LookupValueZeroPadding, - bus_has_lookup, - bus_val, - )); - } + self.add_shout_instance_padding_mode(layout, shout, ShoutPaddingMode::ValueOnly); } /// Add Shout value padding only (no selector booleanity, no addr-bit constraints). pub fn add_shout_instance_value_padding_only(&mut self, layout: &BusLayout, shout: &ShoutCols) { - for j in 0..layout.chunk_size { - let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); - let bus_val = layout.bus_cell(shout.primary_val(), j); - - self.constraints.push(CpuConstraint::new_zero_negated( - CpuConstraintLabel::LookupValueZeroPadding, - bus_has_lookup, - bus_val, - )); - } + self.add_shout_instance_padding_mode(layout, shout, ShoutPaddingMode::ValuePaddingOnly); } /// Add unconditional addr-bit booleanity constraints for one shared-address Shout group. pub fn add_shout_instance_addr_bit_bitness(&mut self, layout: &BusLayout, shout: &ShoutCols) { + self.add_shout_instance_padding_mode(layout, shout, ShoutPaddingMode::AddrBitBitnessOnly); + } + + fn add_shout_instance_padding_mode(&mut self, layout: &BusLayout, shout: &ShoutCols, mode: ShoutPaddingMode) { + let add_selector_bitness = matches!(mode, ShoutPaddingMode::Full | ShoutPaddingMode::ValueOnly); + let add_value_padding = matches!( + mode, + ShoutPaddingMode::Full + | ShoutPaddingMode::WithoutSelectorBitness + | ShoutPaddingMode::ValueOnly + | ShoutPaddingMode::ValuePaddingOnly + ); + let add_gated_addr_bitness = matches!(mode, ShoutPaddingMode::Full | ShoutPaddingMode::WithoutSelectorBitness); + let add_unconditional_addr_bitness = matches!(mode, ShoutPaddingMode::AddrBitBitnessOnly); + for j in 0..layout.chunk_size { - for col_id in shout.addr_bits.clone() { - let bit = layout.bus_cell(col_id, j); - self.add_boolean_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit); + let bus_has_lookup = layout.bus_cell(shout.has_lookup, j); + + if add_selector_bitness { + self.add_boolean_constraint(CpuConstraintLabel::ShoutHasLookupBoolean, bus_has_lookup); + } + + if add_value_padding { + let bus_val = layout.bus_cell(shout.primary_val(), j); + self.constraints.push(CpuConstraint::new_zero_negated( + CpuConstraintLabel::LookupValueZeroPadding, + bus_has_lookup, + bus_val, + )); + } + + if add_gated_addr_bitness { + // Bitness: bit is 0 when inactive, boolean when active. + for col_id in shout.addr_bits.clone() { + let bit = layout.bus_cell(col_id, j); + self.add_gated_bit_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit, bus_has_lookup); + } + } + + if add_unconditional_addr_bitness { + // Unconditionally enforce bit ∈ {0,1}. + for col_id in shout.addr_bits.clone() { + let bit = layout.bus_cell(col_id, j); + self.add_boolean_constraint(CpuConstraintLabel::ShoutAddrBitBitness, bit); + } } } } diff --git a/crates/neo-memory/src/twist_oracle.rs b/crates/neo-memory/src/twist_oracle.rs index 20e4a17a..7adcdca7 100644 --- a/crates/neo-memory/src/twist_oracle.rs +++ b/crates/neo-memory/src/twist_oracle.rs @@ -1,17 +1,3 @@ -//! Twist (and Shout) sumcheck oracles built from multilinear factor tables. -//! -//! This module provides oracles for the **index-bit addressing** architecture: -//! instead of materializing huge one-hot tables, we compute eq(bits, r_addr) -//! dynamically from committed bit columns. -//! -//! ## Key Oracles -//! -//! - `ProductRoundOracle`: Generic multilinear product sumcheck (used for address-domain checks) -//! - `AddressLookupOracle`: Shout address-domain lookup sumcheck -//! - `IndexAdapterOracleSparseTime`: Sparse-in-time IDX→OH adapter (time-domain) -//! - `LazyWeightedBitnessOracleSparseTime`: Sparse-in-time aggregated bitness checks (time-domain) -//! - `TwistReadCheckOracleSparseTime` / `TwistWriteCheckOracleSparseTime`: Twist time-lane checks (time-domain) -//! - `TwistValEvalOracleSparseTime` / `TwistTotalIncOracleSparseTime`: Twist val reconstruction (time-domain) use crate::bit_ops::eq_bit_affine; use crate::mle::{eq_single, lt_eval}; @@ -40,14 +26,6 @@ macro_rules! impl_round_oracle_via_core { }; } -// ============================================================================ -// Core ProductRoundOracle -// ============================================================================ - -/// Helper that runs sumcheck for a product of multilinear factors. -/// -/// Each factor table is a length-2^ℓ vector enumerating the factor on the -/// Boolean hypercube, using little-endian bit order. pub struct ProductRoundOracle { factors: Vec>, rounds_remaining: usize, @@ -104,7 +82,6 @@ impl ProductRoundOracle { impl RoundOracle for ProductRoundOracle { fn evals_at(&mut self, points: &[K]) -> Vec { if self.rounds_remaining == 0 { - // Return the single value for all points let val = self .value() .expect("ProductRoundOracle invariant broken: rounds_remaining==0 but value() is None"); @@ -155,23 +132,13 @@ impl RoundOracle for ProductRoundOracle { } } -// ============================================================================ -// Sparse-in-time helpers and oracles (Track A) -// ============================================================================ - -#[inline] fn chi_at_bool_index(r: &[K], idx: usize) -> K { crate::mle::chi_at_index(r, idx) } -/// Compute χ_{r_cycle}(t) children for the current time sumcheck round. -/// -/// Variable order is little-endian (bit 0 first), matching `ProductRoundOracle`. -#[inline] fn chi_cycle_children(r_cycle: &[K], bit_idx: usize, prefix_eq: K, pair_idx: usize) -> (K, K) { debug_assert!(bit_idx < r_cycle.len()); - // Higher bits (bit_idx+1..ell) come from pair_idx, little-endian. let mut suffix = K::ONE; let mut shift = 0usize; for b in (bit_idx + 1)..r_cycle.len() { @@ -186,29 +153,26 @@ fn chi_cycle_children(r_cycle: &[K], bit_idx: usize, prefix_eq: K, pair_idx: usi (child0, child1) } -fn gather_pairs_from_sparse(entries: &[(usize, K)]) -> Vec { - let mut out: Vec = Vec::with_capacity(entries.len()); - let mut prev: Option = None; - for &(idx, _v) in entries { - let p = idx >> 1; - if prev != Some(p) { - out.push(p); - prev = Some(p); +macro_rules! for_each_sparse_parent_pair { + ($entries:expr, $pair:ident, $body:block) => {{ + let mut prev_pair = usize::MAX; + for &(idx, _) in $entries { + let $pair = idx >> 1; + if $pair == prev_pair { + continue; + } + prev_pair = $pair; + $body } - } - out + }}; +} + +fn expr_id1(cols: &[K; 1]) -> K { + cols[0] } -/// Sparse Route A oracle for Shout value: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · val(t) pub struct ShoutValueOracleSparse { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - val: SparseIdxVec, - degree_bound: usize, - challenges: Vec, + core: SparseTimeExprOracle<1>, } impl ShoutValueOracleSparse { @@ -228,40 +192,62 @@ impl ShoutValueOracleSparse { ( Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - val, - degree_bound: 3, - challenges: Vec::with_capacity(ell_n), + core: SparseTimeExprOracle::new(r_cycle, has_lookup, [val], 3, expr_id1), }, claim, ) } } +impl_round_oracle_via_core!(ShoutValueOracleSparse); + +type SparseExprFn = fn(&[K; N]) -> K; + +struct SparseTimeExprOracle { + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + has_lookup: SparseIdxVec, + cols: [SparseIdxVec; N], + degree_bound: usize, + expr_fn: SparseExprFn, +} + +impl SparseTimeExprOracle { + fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + cols: [SparseIdxVec; N], + degree_bound: usize, + expr_fn: SparseExprFn, + ) -> Self { + let ell_n = r_cycle.len(); + debug_assert_eq!(has_lookup.len(), 1usize << ell_n); + assert_cols_match_time(&cols, 1usize << ell_n); + + Self { + bit_idx: 0, + r_cycle: r_cycle.to_vec(), + prefix_eq: K::ONE, + has_lookup, + cols, + degree_bound, + expr_fn, + } + } +} -impl RoundOracle for ShoutValueOracleSparse { +impl RoundOracle for SparseTimeExprOracle { fn evals_at(&mut self, points: &[K]) -> Vec { if self.has_lookup.len() == 1 { - let v = self.prefix_eq * self.has_lookup.singleton_value() * self.val.singleton_value(); + let gate = self.has_lookup.singleton_value(); + let cols = std::array::from_fn(|i| self.cols[i].singleton_value()); + let expr = (self.expr_fn)(&cols); + let v = self.prefix_eq * gate * expr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let (chi0_base, chi1_base) = if self.bit_idx < self.r_cycle.len() { - // Per-pair child weights depend on higher-bit assignment (pair index). - (K::ZERO, K::ZERO) - } else { - (K::ZERO, K::ZERO) - }; - let _ = (chi0_base, chi1_base); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -271,18 +257,28 @@ impl RoundOracle for ShoutValueOracleSparse { continue; } - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); + let cols0: [K; N] = std::array::from_fn(|i| self.cols[i].get(child0)); + let cols1: [K; N] = std::array::from_fn(|i| self.cols[i].get(child1)); let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); for (i, &x) in points.iter().enumerate() { - let chi_x = chi0 + (chi1 - chi0) * x; - let gate_x = gate0 + (gate1 - gate0) * x; - let val_x = val0 + (val1 - val0) * x; - ys[i] += chi_x * gate_x * val_x; + let chi_x = interp(chi0, chi1, x); + if chi_x == K::ZERO { + continue; + } + let gate_x = interp(gate0, gate1, x); + if gate_x == K::ZERO { + continue; + } + let cols_x: [K; N] = std::array::from_fn(|j| interp(cols0[j], cols1[j], x)); + let expr_x = (self.expr_fn)(&cols_x); + if expr_x == K::ZERO { + continue; + } + ys[i] += chi_x * gate_x * expr_x; } - } + }); ys } @@ -300,87 +296,183 @@ impl RoundOracle for ShoutValueOracleSparse { } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); self.has_lookup.fold_round_in_place(r); - self.val.fold_round_in_place(r); - self.challenges.push(r); + for col in self.cols.iter_mut() { + col.fold_round_in_place(r); + } self.bit_idx += 1; } } -// ============================================================================ -// Packed-key RV32 ADD Shout (time-domain) -// ============================================================================ +fn expr_rv32_packed_add(cols: &[K; 4]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let carry = cols[2]; + let val = cols[3]; + lhs + rhs - val - carry * K::from_u64(1u64 << 32) +} + +fn expr_rv32_packed_sub(cols: &[K; 4]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let borrow = cols[2]; + let val = cols[3]; + lhs - rhs - val + borrow * K::from_u64(1u64 << 32) +} + +fn expr_rv32_packed_mulhsu_adapter(cols: &[K; 6]) -> K { + let rhs = cols[1]; + let lhs_sign = cols[2]; + let hi = cols[3]; + let borrow = cols[4]; + let val = cols[5]; + hi - lhs_sign * rhs - val + borrow * K::from_u64(1u64 << 32) +} + +fn expr_rv32_packed_slt(cols: &[K; 6]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let lhs_sign = cols[2]; + let rhs_sign = cols[3]; + let diff = cols[4]; + let out = cols[5]; + let two = K::from_u64(2); + let two31 = K::from_u64(1u64 << 31); + let two32 = K::from_u64(1u64 << 32); + let lhs_b = lhs + (K::ONE - two * lhs_sign) * two31; + let rhs_b = rhs + (K::ONE - two * rhs_sign) * two31; + lhs_b - rhs_b - diff + out * two32 +} + +fn expr_rv32_packed_sltu(cols: &[K; 4]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let diff = cols[2]; + let out = cols[3]; + lhs - rhs - diff + out * K::from_u64(1u64 << 32) +} + +fn expr_rv32_packed_divu(cols: &[K; 5]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let rem = cols[2]; + let z = cols[3]; + let quot = cols[4]; + let all_ones = K::from_u64(u32::MAX as u64); + z * (quot - all_ones) + (K::ONE - z) * (lhs - rhs * quot - rem) +} + +fn expr_rv32_packed_remu(cols: &[K; 5]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let quot = cols[2]; + let z = cols[3]; + let rem = cols[4]; + z * (rem - lhs) + (K::ONE - z) * (lhs - rhs * quot - rem) +} + +fn expr_rv32_packed_div(cols: &[K; 6]) -> K { + let lhs_sign = cols[0]; + let rhs_sign = cols[1]; + let z = cols[2]; + let q_abs = cols[3]; + let q_is_zero = cols[4]; + let val = cols[5]; + let two = K::from_u64(2); + let two32 = K::from_u64(1u64 << 32); + let all_ones = K::from_u64(u32::MAX as u64); + let div_sign = lhs_sign + rhs_sign - two * lhs_sign * rhs_sign; + let neg_q = (K::ONE - q_is_zero) * (two32 - q_abs); + let q_signed = (K::ONE - div_sign) * q_abs + div_sign * neg_q; + z * (val - all_ones) + (K::ONE - z) * (val - q_signed) +} + +fn expr_rv32_packed_rem(cols: &[K; 6]) -> K { + let lhs = cols[0]; + let lhs_sign = cols[1]; + let z = cols[2]; + let r_abs = cols[3]; + let r_is_zero = cols[4]; + let val = cols[5]; + let two32 = K::from_u64(1u64 << 32); + let neg_r = (K::ONE - r_is_zero) * (two32 - r_abs); + let r_signed = (K::ONE - lhs_sign) * r_abs + lhs_sign * neg_r; + z * (val - lhs) + (K::ONE - z) * (val - r_signed) +} + +fn expr_rv32_packed_mul(cols: &[K; 3], limb_sum: K) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let val = cols[2]; + lhs * rhs - val - limb_sum * K::from_u64(1u64 << 32) +} + +fn expr_rv32_packed_mulhu(cols: &[K; 3], limb_sum: K) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let val = cols[2]; + lhs * rhs - limb_sum - val * K::from_u64(1u64 << 32) +} + +type SparseWeightedBitsExprFn = fn(&[K; N], K) -> K; -/// Sparse Route A oracle for RV32 packed ADD correctness: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) + rhs(t) - val(t) - carry(t)·2^32) -/// -/// This is a "no width bloat" alternative to the 64-bit addr-bit Shout encoding for `ADD`: -/// instead of committing to 64 key bits, we commit to packed `lhs/rhs` plus a carry bit. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedAddOracleSparseTime { +struct SparseWeightedBitsExprOracle { bit_idx: usize, r_cycle: Vec, prefix_eq: K, has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - carry: SparseIdxVec, - val: SparseIdxVec, + cols: [SparseIdxVec; N], + bits: Vec>, + weights: Vec, degree_bound: usize, + expr_fn: SparseWeightedBitsExprFn, } -impl Rv32PackedAddOracleSparseTime { - pub fn new( +impl SparseWeightedBitsExprOracle { + fn new( r_cycle: &[K], has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - carry: SparseIdxVec, - val: SparseIdxVec, + cols: [SparseIdxVec; N], + bits: Vec>, + weights: Vec, + degree_bound: usize, + expr_fn: SparseWeightedBitsExprFn, ) -> Self { let ell_n = r_cycle.len(); debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(carry.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); + assert_cols_match_time(&cols, 1usize << ell_n); + debug_assert_eq!(bits.len(), weights.len()); + assert_cols_match_time(&bits, 1usize << ell_n); Self { bit_idx: 0, r_cycle: r_cycle.to_vec(), prefix_eq: K::ONE, has_lookup, - lhs, - rhs, - carry, - val, - degree_bound: 3, + cols, + bits, + weights, + degree_bound, + expr_fn, } } } -impl RoundOracle for Rv32PackedAddOracleSparseTime { +impl RoundOracle for SparseWeightedBitsExprOracle { fn evals_at(&mut self, points: &[K]) -> Vec { if self.has_lookup.len() == 1 { - let two32 = K::from_u64(1u64 << 32); let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let carry = self.carry.singleton_value(); - let val = self.val.singleton_value(); - let expr = lhs + rhs - val - carry * two32; + let cols = std::array::from_fn(|i| self.cols[i].singleton_value()); + let mut sum = K::ZERO; + for (b, w) in self.bits.iter().zip(self.weights.iter()) { + sum += b.singleton_value() * *w; + } + let expr = (self.expr_fn)(&cols, sum); let v = self.prefix_eq * gate * expr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let two32 = K::from_u64(1u64 << 32); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -390,17 +482,14 @@ impl RoundOracle for Rv32PackedAddOracleSparseTime { continue; } - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let carry0 = self.carry.get(child0); - let carry1 = self.carry.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - let expr0 = lhs0 + rhs0 - val0 - carry0 * two32; - let expr1 = lhs1 + rhs1 - val1 - carry1 * two32; + let cols0: [K; N] = std::array::from_fn(|i| self.cols[i].get(child0)); + let cols1: [K; N] = std::array::from_fn(|i| self.cols[i].get(child1)); + let mut sum0 = K::ZERO; + let mut sum1 = K::ZERO; + for (b, w) in self.bits.iter().zip(self.weights.iter()) { + sum0 += b.get(child0) * *w; + sum1 += b.get(child1) * *w; + } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -413,13 +502,15 @@ impl RoundOracle for Rv32PackedAddOracleSparseTime { if gate_x == K::ZERO { continue; } - let expr_x = interp(expr0, expr1, x); + let cols_x: [K; N] = std::array::from_fn(|j| interp(cols0[j], cols1[j], x)); + let sum_x = interp(sum0, sum1, x); + let expr_x = (self.expr_fn)(&cols_x, sum_x); if expr_x == K::ZERO { continue; } ys[i] += chi_x * gate_x * expr_x; } - } + }); ys } @@ -437,85 +528,72 @@ impl RoundOracle for Rv32PackedAddOracleSparseTime { } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.carry.fold_round_in_place(r); - self.val.fold_round_in_place(r); + for col in self.cols.iter_mut() { + col.fold_round_in_place(r); + } + for b in self.bits.iter_mut() { + b.fold_round_in_place(r); + } self.bit_idx += 1; } } -/// Sparse Route A oracle for RV32 packed SUB correctness: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - rhs(t) - val(t) + borrow(t)·2^32) -/// -/// This is a "no width bloat" alternative to the 64-bit addr-bit Shout encoding for `SUB`: -/// instead of committing to 64 key bits, we commit to packed `lhs/rhs` plus a borrow bit. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSubOracleSparseTime { +type SparseValProdExprFn = fn(K, K) -> K; + +struct SparseValProdBitsOracle { bit_idx: usize, r_cycle: Vec, prefix_eq: K, has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - borrow: SparseIdxVec, val: SparseIdxVec, + bits: Vec>, degree_bound: usize, + expr_fn: SparseValProdExprFn, } -impl Rv32PackedSubOracleSparseTime { - pub fn new( +impl SparseValProdBitsOracle { + fn new( r_cycle: &[K], has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - borrow: SparseIdxVec, val: SparseIdxVec, + bits: Vec>, + degree_bound: usize, + expr_fn: SparseValProdExprFn, ) -> Self { let ell_n = r_cycle.len(); debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(borrow.len(), 1usize << ell_n); debug_assert_eq!(val.len(), 1usize << ell_n); + assert_cols_match_time(&bits, 1usize << ell_n); Self { bit_idx: 0, r_cycle: r_cycle.to_vec(), prefix_eq: K::ONE, has_lookup, - lhs, - rhs, - borrow, val, - degree_bound: 3, + bits, + degree_bound, + expr_fn, } } } -impl RoundOracle for Rv32PackedSubOracleSparseTime { +impl RoundOracle for SparseValProdBitsOracle { fn evals_at(&mut self, points: &[K]) -> Vec { if self.has_lookup.len() == 1 { - let two32 = K::from_u64(1u64 << 32); let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let borrow = self.borrow.singleton_value(); let val = self.val.singleton_value(); - let expr = lhs - rhs - val + borrow * two32; + let mut prod = K::ONE; + for b in self.bits.iter() { + prod *= K::ONE - b.singleton_value(); + } + let expr = (self.expr_fn)(val, prod); let v = self.prefix_eq * gate * expr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let two32 = K::from_u64(1u64 << 32); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -525,17 +603,13 @@ impl RoundOracle for Rv32PackedSubOracleSparseTime { continue; } - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let borrow0 = self.borrow.get(child0); - let borrow1 = self.borrow.get(child1); let val0 = self.val.get(child0); let val1 = self.val.get(child1); - - let expr0 = lhs0 - rhs0 - val0 + borrow0 * two32; - let expr1 = lhs1 - rhs1 - val1 + borrow1 * two32; + let bit_pairs: Vec<(K, K)> = self + .bits + .iter() + .map(|b| (b.get(child0), b.get(child1))) + .collect(); let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -548,13 +622,19 @@ impl RoundOracle for Rv32PackedSubOracleSparseTime { if gate_x == K::ZERO { continue; } - let expr_x = interp(expr0, expr1, x); + let val_x = interp(val0, val1, x); + let mut prod_x = K::ONE; + for &(b0, b1) in bit_pairs.iter() { + let bit_x = interp(b0, b1, x); + prod_x *= K::ONE - bit_x; + } + let expr_x = (self.expr_fn)(val_x, prod_x); if expr_x == K::ZERO { continue; } ys[i] += chi_x * gate_x * expr_x; } - } + }); ys } @@ -572,107 +652,119 @@ impl RoundOracle for Rv32PackedSubOracleSparseTime { } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.borrow.fold_round_in_place(r); self.val.fold_round_in_place(r); + for b in self.bits.iter_mut() { + b.fold_round_in_place(r); + } self.bit_idx += 1; } } -// ============================================================================ -// Packed-key RV32 MUL Shout (time-domain) -// ============================================================================ - -/// Sparse Route A oracle for RV32 packed MUL correctness: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t)·rhs(t) - val(t) - carry(t)·2^32) -/// -/// Where: -/// - `carry(t)` is the high 32 bits of the 64-bit product `lhs·rhs`, encoded as 32 Boolean columns, -/// - `val(t)` is the low 32 bits (the `MUL` result). -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedMulOracleSparseTime { +struct SparseShiftRemBoundOracle { bit_idx: usize, r_cycle: Vec, prefix_eq: K, has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - carry_bits: Vec>, - val: SparseIdxVec, - degree_bound: usize, + shamt_bits: Vec>, + rem_bits: Vec>, } -impl Rv32PackedMulOracleSparseTime { - pub fn new( +fn shamt_values_singleton(shamt_bits: &[SparseIdxVec]) -> [K; 5] { + let mut shamt = [K::ZERO; 5]; + for (i, b) in shamt_bits.iter().enumerate() { + shamt[i] = b.singleton_value(); + } + shamt +} + +fn shamt_values_pair(shamt_bits: &[SparseIdxVec], child0: usize, child1: usize) -> ([K; 5], [K; 5]) { + let mut b0s = [K::ZERO; 5]; + let mut b1s = [K::ZERO; 5]; + for (i, b) in shamt_bits.iter().enumerate() { + b0s[i] = b.get(child0); + b1s[i] = b.get(child1); + } + (b0s, b1s) +} + +fn shamt_values_interp(b0s: &[K; 5], b1s: &[K; 5], x: K) -> [K; 5] { + let mut shamt = [K::ZERO; 5]; + for j in 0..5 { + shamt[j] = interp(b0s[j], b1s[j], x); + } + shamt +} + +fn pow2_from_shamt(shamt: &[K; 5]) -> K { + let pow2_const = [2u64, 4, 16, 256, 65536]; + let mut pow2 = K::ONE; + for (b, c) in shamt.iter().zip(pow2_const.iter()) { + let c = K::from_u64(*c); + pow2 *= K::ONE + *b * (c - K::ONE); + } + pow2 +} + +fn shift_rem_bound_expr(shamt: [K; 5], rem_bits_len: usize, mut rem_at: F) -> K +where + F: FnMut(usize) -> K, +{ + let mut tail_sum: [K; 32] = [K::ZERO; 32]; + let mut tail = K::ZERO; + for j in (0..rem_bits_len).rev() { + tail += rem_at(j) * K::from_u64(1u64 << j); + tail_sum[j] = tail; + } + + let mut expr = K::ZERO; + for s in 0..32usize { + let mut prod = K::ONE; + for j in 0..5usize { + let b = shamt[j]; + prod *= if ((s >> j) & 1) == 1 { b } else { K::ONE - b }; + } + expr += prod * tail_sum[s]; + } + expr +} + +impl SparseShiftRemBoundOracle { + fn new( r_cycle: &[K], has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - carry_bits: Vec>, - val: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, ) -> Self { let ell_n = r_cycle.len(); debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(carry_bits.len(), 32); - for (i, b) in carry_bits.iter().enumerate() { - debug_assert_eq!( - b.len(), - 1usize << ell_n, - "carry_bits[{i}] length must match time domain" - ); - } + debug_assert_eq!(shamt_bits.len(), 5); + assert_cols_match_time(&shamt_bits, 1usize << ell_n); + debug_assert!(rem_bits.len() <= 32); + assert_cols_match_time(&rem_bits, 1usize << ell_n); Self { bit_idx: 0, r_cycle: r_cycle.to_vec(), prefix_eq: K::ONE, has_lookup, - lhs, - rhs, - carry_bits, - val, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - lhs(t)·rhs(t): degree 2 - // - val(t), carry(t): multilinear (degree 1) - // ⇒ total degree ≤ 1 + 1 + 2 = 4 - degree_bound: 4, + shamt_bits, + rem_bits, } } } -impl RoundOracle for Rv32PackedMulOracleSparseTime { +impl RoundOracle for SparseShiftRemBoundOracle { fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - if self.has_lookup.len() == 1 { let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let val = self.val.singleton_value(); - - let mut carry = K::ZERO; - for (i, b) in self.carry_bits.iter().enumerate() { - carry += b.singleton_value() * K::from_u64(1u64 << i); - } - - let expr = lhs * rhs - val - carry * two32; + let shamt = shamt_values_singleton(&self.shamt_bits); + let expr = shift_rem_bound_expr(shamt, self.rem_bits.len(), |j| self.rem_bits[j].singleton_value()); let v = self.prefix_eq * gate * expr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -682,19 +774,12 @@ impl RoundOracle for Rv32PackedMulOracleSparseTime { continue; } - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - // Pre-fetch carry bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut c0s: [K; 32] = [K::ZERO; 32]; - let mut c1s: [K; 32] = [K::ZERO; 32]; - for (i, b) in self.carry_bits.iter().enumerate() { - c0s[i] = b.get(child0); - c1s[i] = b.get(child1); + let (b0s, b1s) = shamt_values_pair(&self.shamt_bits, child0, child1); + let mut r0s = Vec::with_capacity(self.rem_bits.len()); + let mut r1s = Vec::with_capacity(self.rem_bits.len()); + for b in self.rem_bits.iter() { + r0s.push(b.get(child0)); + r1s.push(b.get(child1)); } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -709,24 +794,15 @@ impl RoundOracle for Rv32PackedMulOracleSparseTime { continue; } - let lhs_x = interp(lhs0, lhs1, x); - let rhs_x = interp(rhs0, rhs1, x); - let val_x = interp(val0, val1, x); - - let mut carry_x = K::ZERO; - for j in 0..32 { - let c_x = interp(c0s[j], c1s[j], x); - carry_x += c_x * K::from_u64(1u64 << j); - } + let shamt = shamt_values_interp(&b0s, &b1s, x); + let expr_x = shift_rem_bound_expr(shamt, self.rem_bits.len(), |j| interp(r0s[j], r1s[j], x)); - let expr_x = lhs_x * rhs_x - val_x - carry_x * two32; if expr_x == K::ZERO { continue; } - ys[i] += chi_x * gate_x * expr_x; } - } + }); ys } @@ -736,7 +812,7 @@ impl RoundOracle for Rv32PackedMulOracleSparseTime { } fn degree_bound(&self) -> usize { - self.degree_bound + 8 } fn fold(&mut self, r: K) { @@ -745,57 +821,53 @@ impl RoundOracle for Rv32PackedMulOracleSparseTime { } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - for b in self.carry_bits.iter_mut() { + for b in self.shamt_bits.iter_mut() { + b.fold_round_in_place(r); + } + for b in self.rem_bits.iter_mut() { b.fold_round_in_place(r); } - self.val.fold_round_in_place(r); self.bit_idx += 1; } } -// ============================================================================ -// Packed-key RV32 MULHU Shout (time-domain) -// ============================================================================ +type SparseShiftExprFn = fn(lhs: K, val: K, pow2: K, limb_sum: K, sign: K) -> K; -/// Sparse Route A oracle for RV32 packed MULHU correctness (unsigned high 32 bits): -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t)·rhs(t) - lo(t) - val(t)·2^32) -/// -/// Where: -/// - `lo(t)` is the low 32 bits of the 64-bit product `lhs·rhs`, encoded as 32 Boolean columns, -/// - `val(t)` is the upper 32 bits (the `MULHU` result). -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedMulhuOracleSparseTime { +struct SparseShiftExprOracle { bit_idx: usize, r_cycle: Vec, prefix_eq: K, has_lookup: SparseIdxVec, lhs: SparseIdxVec, - rhs: SparseIdxVec, - lo_bits: Vec>, + shamt_bits: Vec>, + bits: Vec>, + sign: Option>, val: SparseIdxVec, degree_bound: usize, + expr_fn: SparseShiftExprFn, } -impl Rv32PackedMulhuOracleSparseTime { - pub fn new( +impl SparseShiftExprOracle { + fn new( r_cycle: &[K], has_lookup: SparseIdxVec, lhs: SparseIdxVec, - rhs: SparseIdxVec, - lo_bits: Vec>, + shamt_bits: Vec>, + bits: Vec>, + sign: Option>, val: SparseIdxVec, + degree_bound: usize, + expr_fn: SparseShiftExprFn, ) -> Self { let ell_n = r_cycle.len(); debug_assert_eq!(has_lookup.len(), 1usize << ell_n); debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(lo_bits.len(), 32); - for (i, b) in lo_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "lo_bits[{i}] length must match time domain"); + debug_assert_eq!(shamt_bits.len(), 5); + assert_cols_match_time(&shamt_bits, 1usize << ell_n); + assert_cols_match_time(&bits, 1usize << ell_n); + if let Some(s) = sign.as_ref() { + debug_assert_eq!(s.len(), 1usize << ell_n); } Self { @@ -804,46 +876,42 @@ impl Rv32PackedMulhuOracleSparseTime { prefix_eq: K::ONE, has_lookup, lhs, - rhs, - lo_bits, + shamt_bits, + bits, + sign, val, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - lhs(t)·rhs(t): degree 2 - // - lo(t), val(t): multilinear (degree 1) - // ⇒ total degree ≤ 1 + 1 + 2 = 4 - degree_bound: 4, + degree_bound, + expr_fn, } } } -impl RoundOracle for Rv32PackedMulhuOracleSparseTime { +impl RoundOracle for SparseShiftExprOracle { fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - if self.has_lookup.len() == 1 { let gate = self.has_lookup.singleton_value(); let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); let val = self.val.singleton_value(); + let shamt = shamt_values_singleton(&self.shamt_bits); + let pow2 = pow2_from_shamt(&shamt); - let mut lo = K::ZERO; - for (i, b) in self.lo_bits.iter().enumerate() { - lo += b.singleton_value() * K::from_u64(1u64 << i); + let mut limb_sum = K::ZERO; + for (i, b) in self.bits.iter().enumerate() { + limb_sum += b.singleton_value() * K::from_u64(1u64 << i); } - let expr = lhs * rhs - lo - val * two32; + let sign = self + .sign + .as_ref() + .map(|s| s.singleton_value()) + .unwrap_or(K::ZERO); + let expr = (self.expr_fn)(lhs, val, pow2, limb_sum, sign); let v = self.prefix_eq * gate * expr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -855,17 +923,17 @@ impl RoundOracle for Rv32PackedMulhuOracleSparseTime { let lhs0 = self.lhs.get(child0); let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); let val0 = self.val.get(child0); let val1 = self.val.get(child1); + let sign0 = self.sign.as_ref().map(|s| s.get(child0)).unwrap_or(K::ZERO); + let sign1 = self.sign.as_ref().map(|s| s.get(child1)).unwrap_or(K::ZERO); - // Pre-fetch lo bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut lo0s: [K; 32] = [K::ZERO; 32]; - let mut lo1s: [K; 32] = [K::ZERO; 32]; - for (i, b) in self.lo_bits.iter().enumerate() { - lo0s[i] = b.get(child0); - lo1s[i] = b.get(child1); + let (b0s, b1s) = shamt_values_pair(&self.shamt_bits, child0, child1); + let mut l0s = Vec::with_capacity(self.bits.len()); + let mut l1s = Vec::with_capacity(self.bits.len()); + for b in self.bits.iter() { + l0s.push(b.get(child0)); + l1s.push(b.get(child1)); } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -881,24 +949,25 @@ impl RoundOracle for Rv32PackedMulhuOracleSparseTime { } let lhs_x = interp(lhs0, lhs1, x); - let rhs_x = interp(rhs0, rhs1, x); let val_x = interp(val0, val1, x); + let sign_x = interp(sign0, sign1, x); - let mut lo_x = K::ZERO; - for j in 0..32 { - let b_x = interp(lo0s[j], lo1s[j], x); - lo_x += b_x * K::from_u64(1u64 << j); + let shamt_x = shamt_values_interp(&b0s, &b1s, x); + let pow2_x = pow2_from_shamt(&shamt_x); + + let mut limb_sum_x = K::ZERO; + for j in 0..self.bits.len() { + let bit_x = interp(l0s[j], l1s[j], x); + limb_sum_x += bit_x * K::from_u64(1u64 << j); } - let expr_x = lhs_x * rhs_x - lo_x - val_x * two32; + let expr_x = (self.expr_fn)(lhs_x, val_x, pow2_x, limb_sum_x, sign_x); if expr_x == K::ZERO { continue; } - ys[i] += chi_x * gate_x * expr_x; } - } - + }); ys } @@ -917,103 +986,95 @@ impl RoundOracle for Rv32PackedMulhuOracleSparseTime { self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); self.has_lookup.fold_round_in_place(r); self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - for b in self.lo_bits.iter_mut() { + for b in self.shamt_bits.iter_mut() { + b.fold_round_in_place(r); + } + for b in self.bits.iter_mut() { b.fold_round_in_place(r); } + if let Some(s) = self.sign.as_mut() { + s.fold_round_in_place(r); + } self.val.fold_round_in_place(r); self.bit_idx += 1; } } -// ============================================================================ -// Packed-key RV32 MULH / MULHSU helpers (time-domain) -// ============================================================================ +fn expr_rv32_packed_sll(lhs: K, val: K, pow2: K, carry: K, _sign: K) -> K { + lhs * pow2 - val - carry * K::from_u64(1u64 << 32) +} -/// Sparse Route A oracle for RV32 unsigned product decomposition: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t)·rhs(t) - lo(t) - hi(t)·2^32) -/// -/// Where: -/// - `lo(t)` is the low 32 bits of the 64-bit product, encoded as 32 Boolean columns, -/// - `hi(t)` is a witness column intended to be the upper 32 bits. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedMulHiOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, +fn expr_rv32_packed_srl(lhs: K, val: K, pow2: K, rem: K, _sign: K) -> K { + lhs - val * pow2 - rem +} + +fn expr_rv32_packed_sra(lhs: K, val: K, pow2: K, rem: K, sign: K) -> K { + lhs - val * pow2 - rem - sign * K::from_u64(1u64 << 32) * (K::ONE - pow2) +} + +type SparseBitsAndWeightsExprFn = fn(&[K; N], K, &[K; W]) -> K; + +struct SparseBitsAndWeightsExprOracle { + bit_idx: usize, + r_cycle: Vec, prefix_eq: K, has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lo_bits: Vec>, - hi: SparseIdxVec, + cols: [SparseIdxVec; N], + bits: Vec>, + bit_weights: Vec, + expr_weights: [K; W], degree_bound: usize, + expr_fn: SparseBitsAndWeightsExprFn, } -impl Rv32PackedMulHiOracleSparseTime { - pub fn new( +impl SparseBitsAndWeightsExprOracle { + fn new( r_cycle: &[K], has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lo_bits: Vec>, - hi: SparseIdxVec, + cols: [SparseIdxVec; N], + bits: Vec>, + bit_weights: Vec, + expr_weights: [K; W], + degree_bound: usize, + expr_fn: SparseBitsAndWeightsExprFn, ) -> Self { let ell_n = r_cycle.len(); debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(hi.len(), 1usize << ell_n); - debug_assert_eq!(lo_bits.len(), 32); - for (i, b) in lo_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "lo_bits[{i}] length must match time domain"); - } + assert_cols_match_time(&cols, 1usize << ell_n); + debug_assert_eq!(bits.len(), bit_weights.len()); + assert_cols_match_time(&bits, 1usize << ell_n); Self { bit_idx: 0, r_cycle: r_cycle.to_vec(), prefix_eq: K::ONE, has_lookup, - lhs, - rhs, - lo_bits, - hi, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - lhs(t)·rhs(t): degree 2 - // ⇒ total degree ≤ 1 + 1 + 2 = 4 - degree_bound: 4, + cols, + bits, + bit_weights, + expr_weights, + degree_bound, + expr_fn, } } } -impl RoundOracle for Rv32PackedMulHiOracleSparseTime { +impl RoundOracle for SparseBitsAndWeightsExprOracle { fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - if self.has_lookup.len() == 1 { let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let hi = self.hi.singleton_value(); - - let mut lo = K::ZERO; - for (i, b) in self.lo_bits.iter().enumerate() { - lo += b.singleton_value() * K::from_u64(1u64 << i); + let cols: [K; N] = std::array::from_fn(|i| self.cols[i].singleton_value()); + let mut bit_sum = K::ZERO; + for (b, w) in self.bits.iter().zip(self.bit_weights.iter()) { + bit_sum += b.singleton_value() * *w; } - - let expr = lhs * rhs - lo - hi * two32; + let expr = (self.expr_fn)(&cols, bit_sum, &self.expr_weights); let v = self.prefix_eq * gate * expr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -1023,19 +1084,13 @@ impl RoundOracle for Rv32PackedMulHiOracleSparseTime { continue; } - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let hi0 = self.hi.get(child0); - let hi1 = self.hi.get(child1); - - // Pre-fetch lo bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut lo0s: [K; 32] = [K::ZERO; 32]; - let mut lo1s: [K; 32] = [K::ZERO; 32]; - for (i, b) in self.lo_bits.iter().enumerate() { - lo0s[i] = b.get(child0); - lo1s[i] = b.get(child1); + let cols0: [K; N] = std::array::from_fn(|i| self.cols[i].get(child0)); + let cols1: [K; N] = std::array::from_fn(|i| self.cols[i].get(child1)); + let mut bit_sum0 = K::ZERO; + let mut bit_sum1 = K::ZERO; + for (b, w) in self.bits.iter().zip(self.bit_weights.iter()) { + bit_sum0 += b.get(child0) * *w; + bit_sum1 += b.get(child1) * *w; } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -1049,26 +1104,15 @@ impl RoundOracle for Rv32PackedMulHiOracleSparseTime { if gate_x == K::ZERO { continue; } - - let lhs_x = interp(lhs0, lhs1, x); - let rhs_x = interp(rhs0, rhs1, x); - let hi_x = interp(hi0, hi1, x); - - let mut lo_x = K::ZERO; - for j in 0..32 { - let b_x = interp(lo0s[j], lo1s[j], x); - lo_x += b_x * K::from_u64(1u64 << j); - } - - let expr_x = lhs_x * rhs_x - lo_x - hi_x * two32; + let cols_x: [K; N] = std::array::from_fn(|j| interp(cols0[j], cols1[j], x)); + let bit_sum_x = interp(bit_sum0, bit_sum1, x); + let expr_x = (self.expr_fn)(&cols_x, bit_sum_x, &self.expr_weights); if expr_x == K::ZERO { continue; } - ys[i] += chi_x * gate_x * expr_x; } - } - + }); ys } @@ -1086,117 +1130,76 @@ impl RoundOracle for Rv32PackedMulHiOracleSparseTime { } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - for b in self.lo_bits.iter_mut() { + for col in self.cols.iter_mut() { + col.fold_round_in_place(r); + } + for b in self.bits.iter_mut() { b.fold_round_in_place(r); } - self.hi.fold_round_in_place(r); self.bit_idx += 1; } } -/// Sparse Route A oracle for RV32 packed MULH signed correction: -/// Σ_t χ(t)·has_lookup(t)·(w0·(hi - s1·rhs - s2·lhs + k·2^32 - val) + w1·k(k-1)(k-2)) -/// -/// Where: -/// - `hi` is the upper 32 bits of the unsigned product `lhs·rhs`, -/// - `s1`, `s2` are witness sign bits (msb of lhs/rhs), -/// - `k ∈ {0,1,2}` accounts for mod-2^32 normalization. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedMulhAdapterOracleSparseTime { +type SparseColsBitsExprFn = fn(&[K; N], &[K], &[K; W]) -> K; + +struct SparseColsBitsExprOracle { bit_idx: usize, r_cycle: Vec, prefix_eq: K, has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - hi: SparseIdxVec, - k: SparseIdxVec, - val: SparseIdxVec, - weights: [K; 2], + cols: [SparseIdxVec; N], + bits: Vec>, + expr_weights: [K; W], degree_bound: usize, + expr_fn: SparseColsBitsExprFn, } -impl Rv32PackedMulhAdapterOracleSparseTime { - pub fn new( +impl SparseColsBitsExprOracle { + fn new( r_cycle: &[K], has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - hi: SparseIdxVec, - k: SparseIdxVec, - val: SparseIdxVec, - weights: [K; 2], + cols: [SparseIdxVec; N], + bits: Vec>, + expr_weights: [K; W], + degree_bound: usize, + expr_fn: SparseColsBitsExprFn, ) -> Self { let ell_n = r_cycle.len(); debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(hi.len(), 1usize << ell_n); - debug_assert_eq!(k.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); + assert_cols_match_time(&cols, 1usize << ell_n); + assert_cols_match_time(&bits, 1usize << ell_n); Self { bit_idx: 0, r_cycle: r_cycle.to_vec(), prefix_eq: K::ONE, has_lookup, - lhs, - rhs, - lhs_sign, - rhs_sign, - hi, - k, - val, - weights, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - eq expr: degree 2 (sign·rhs) - // - range poly: degree 3 - // ⇒ total degree ≤ 1 + 1 + 3 = 5 - degree_bound: 5, + cols, + bits, + expr_weights, + degree_bound, + expr_fn, } } } -impl RoundOracle for Rv32PackedMulhAdapterOracleSparseTime { +impl RoundOracle for SparseColsBitsExprOracle { fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - let w0 = self.weights[0]; - let w1 = self.weights[1]; - if self.has_lookup.len() == 1 { let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let lhs_sign = self.lhs_sign.singleton_value(); - let rhs_sign = self.rhs_sign.singleton_value(); - let hi = self.hi.singleton_value(); - let k = self.k.singleton_value(); - let val = self.val.singleton_value(); - - let eq_expr = hi - lhs_sign * rhs - rhs_sign * lhs + k * two32 - val; - let range = k * (k - K::ONE) * (k - K::from_u64(2)); - let expr = w0 * eq_expr + w1 * range; + let cols: [K; N] = std::array::from_fn(|i| self.cols[i].singleton_value()); + let bits: Vec = self + .bits + .iter() + .map(SparseIdxVec::singleton_value) + .collect(); + let expr = (self.expr_fn)(&cols, &bits, &self.expr_weights); let v = self.prefix_eq * gate * expr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -1206,20 +1209,16 @@ impl RoundOracle for Rv32PackedMulhAdapterOracleSparseTime { continue; } - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let lhs_sign0 = self.lhs_sign.get(child0); - let lhs_sign1 = self.lhs_sign.get(child1); - let rhs_sign0 = self.rhs_sign.get(child0); - let rhs_sign1 = self.rhs_sign.get(child1); - let hi0 = self.hi.get(child0); - let hi1 = self.hi.get(child1); - let k0 = self.k.get(child0); - let k1 = self.k.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); + let cols0: [K; N] = std::array::from_fn(|i| self.cols[i].get(child0)); + let cols1: [K; N] = std::array::from_fn(|i| self.cols[i].get(child1)); + + let mut bits0 = Vec::with_capacity(self.bits.len()); + let mut bits1 = Vec::with_capacity(self.bits.len()); + for b in self.bits.iter() { + bits0.push(b.get(child0)); + bits1.push(b.get(child1)); + } + let mut bits_x = vec![K::ZERO; self.bits.len()]; let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -1233,25 +1232,18 @@ impl RoundOracle for Rv32PackedMulhAdapterOracleSparseTime { continue; } - let lhs_x = interp(lhs0, lhs1, x); - let rhs_x = interp(rhs0, rhs1, x); - let lhs_sign_x = interp(lhs_sign0, lhs_sign1, x); - let rhs_sign_x = interp(rhs_sign0, rhs_sign1, x); - let hi_x = interp(hi0, hi1, x); - let k_x = interp(k0, k1, x); - let val_x = interp(val0, val1, x); + let cols_x: [K; N] = std::array::from_fn(|j| interp(cols0[j], cols1[j], x)); + for j in 0..bits_x.len() { + bits_x[j] = interp(bits0[j], bits1[j], x); + } - let eq_expr = hi_x - lhs_sign_x * rhs_x - rhs_sign_x * lhs_x + k_x * two32 - val_x; - let range = k_x * (k_x - K::ONE) * (k_x - K::from_u64(2)); - let expr_x = w0 * eq_expr + w1 * range; + let expr_x = (self.expr_fn)(&cols_x, &bits_x, &self.expr_weights); if expr_x == K::ZERO { continue; } - ys[i] += chi_x * gate_x * expr_x; } - } - + }); ys } @@ -1269,3939 +1261,717 @@ impl RoundOracle for Rv32PackedMulhAdapterOracleSparseTime { } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.lhs_sign.fold_round_in_place(r); - self.rhs_sign.fold_round_in_place(r); - self.hi.fold_round_in_place(r); - self.k.fold_round_in_place(r); - self.val.fold_round_in_place(r); + for col in self.cols.iter_mut() { + col.fold_round_in_place(r); + } + for b in self.bits.iter_mut() { + b.fold_round_in_place(r); + } self.bit_idx += 1; } } -/// Sparse Route A oracle for RV32 packed MULHSU signed correction: -/// Σ_t χ(t)·has_lookup(t)·(hi - s·rhs - val + b·2^32) -/// -/// Where: -/// - `hi` is the upper 32 bits of the unsigned product `lhs·rhs`, -/// - `s` is the witness sign bit of `lhs`, -/// - `b ∈ {0,1}` is a borrow bit for mod-2^32 normalization. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedMulhsuAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lhs_sign: SparseIdxVec, - hi: SparseIdxVec, - borrow: SparseIdxVec, - val: SparseIdxVec, - degree_bound: usize, +fn expr_rv32_packed_mulh_adapter(cols: &[K; 7], _bit_sum: K, w: &[K; 2]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let lhs_sign = cols[2]; + let rhs_sign = cols[3]; + let hi = cols[4]; + let k = cols[5]; + let val = cols[6]; + let eq_expr = hi - lhs_sign * rhs - rhs_sign * lhs + k * K::from_u64(1u64 << 32) - val; + let range = k * (k - K::ONE) * (k - K::from_u64(2)); + w[0] * eq_expr + w[1] * range +} + +fn expr_rv32_packed_divremu_adapter(cols: &[K; 4], bit_sum: K, w: &[K; 4]) -> K { + let rhs = cols[0]; + let z = cols[1]; + let rem = cols[2]; + let diff = cols[3]; + let c0 = z * (K::ONE - z); + let c1 = z * rhs; + let c2 = (K::ONE - z) * (rem - rhs - diff + K::from_u64(1u64 << 32)); + let c3 = diff - bit_sum; + w[0] * c0 + w[1] * c1 + w[2] * c2 + w[3] * c3 +} + +fn expr_rv32_packed_divrem_adapter(cols: &[K; 10], bit_sum: K, w: &[K; 7]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let z = cols[2]; + let lhs_sign = cols[3]; + let rhs_sign = cols[4]; + let q_abs = cols[5]; + let r_abs = cols[6]; + let mag = cols[7]; + let mag_z = cols[8]; + let diff = cols[9]; + let two = K::from_u64(2); + let two32 = K::from_u64(1u64 << 32); + let lhs_abs = lhs + lhs_sign * (two32 - two * lhs); + let rhs_abs = rhs + rhs_sign * (two32 - two * rhs); + let c0 = z * (K::ONE - z); + let c1 = z * rhs; + let c2 = mag_z * (K::ONE - mag_z); + let c3 = mag_z * mag; + let c4 = (K::ONE - z) * (lhs_abs - rhs_abs * q_abs - r_abs); + let c5 = (K::ONE - z) * (r_abs - rhs_abs - diff + two32); + let c6 = diff - bit_sum; + w[0] * c0 + w[1] * c1 + w[2] * c2 + w[3] * c3 + w[4] * c4 + w[5] * c5 + w[6] * c6 +} + +fn expr_rv32_packed_eq_adapter(cols: &[K; 3], diff: K) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + let borrow = cols[2]; + lhs - rhs - diff + borrow * K::from_u64(1u64 << 32) +} + +fn expr_u_decomp(cols: &[K; 1], sum: K) -> K { + cols[0] - sum +} + +fn expr_eq_from_prod(val: K, prod: K) -> K { + val - prod +} + +fn expr_neq_from_prod(val: K, prod: K) -> K { + val + prod - K::ONE +} + +macro_rules! define_sparse_time_expr_oracle { + ($name:ident, $n:expr, $degree:expr, $expr:expr, [$($col:ident),+ $(,)?]) => { + pub struct $name { + core: SparseTimeExprOracle<$n>, + } + + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + $($col: SparseIdxVec,)+ + ) -> Self { + Self { + core: SparseTimeExprOracle::new( + r_cycle, + has_lookup, + [$($col),+], + $degree, + $expr, + ), + } + } + } + + impl_round_oracle_via_core!($name); + }; +} + +define_sparse_time_expr_oracle!( + Rv32PackedAddOracleSparseTime, + 4, + 3, + expr_rv32_packed_add, + [lhs, rhs, carry, val] +); + +define_sparse_time_expr_oracle!( + Rv32PackedSubOracleSparseTime, + 4, + 3, + expr_rv32_packed_sub, + [lhs, rhs, borrow, val] +); + +macro_rules! define_weighted_bits32_oracle3 { + ($name:ident, [$c0:ident, $c1:ident, $c2:ident], $expr:expr, $degree:expr) => { + pub struct $name { + core: SparseWeightedBitsExprOracle<3>, + } + + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + $c0: SparseIdxVec, + $c1: SparseIdxVec, + $c2: SparseIdxVec, + bits: Vec>, + ) -> Self { + debug_assert_eq!(bits.len(), 32); + let weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); + Self { + core: SparseWeightedBitsExprOracle::new( + r_cycle, + has_lookup, + [$c0, $c1, $c2], + bits, + weights, + $degree, + $expr, + ), + } + } + } + impl_round_oracle_via_core!($name); + }; +} + +macro_rules! define_weighted_bits32_oracle3_bits_before_last { + ($name:ident, [$c0:ident, $c1:ident], $bits:ident, $c2:ident, $expr:expr, $degree:expr) => { + pub struct $name { + core: SparseWeightedBitsExprOracle<3>, + } + + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + $c0: SparseIdxVec, + $c1: SparseIdxVec, + $bits: Vec>, + $c2: SparseIdxVec, + ) -> Self { + debug_assert_eq!($bits.len(), 32); + let weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); + Self { + core: SparseWeightedBitsExprOracle::new( + r_cycle, + has_lookup, + [$c0, $c1, $c2], + $bits, + weights, + $degree, + $expr, + ), + } + } + } + impl_round_oracle_via_core!($name); + }; +} + +define_weighted_bits32_oracle3_bits_before_last!( + Rv32PackedMulOracleSparseTime, + [lhs, rhs], + carry_bits, + val, + expr_rv32_packed_mul, + 4 +); + +define_weighted_bits32_oracle3_bits_before_last!( + Rv32PackedMulhuOracleSparseTime, + [lhs, rhs], + lo_bits, + val, + expr_rv32_packed_mulhu, + 4 +); + +define_weighted_bits32_oracle3_bits_before_last!( + Rv32PackedMulHiOracleSparseTime, + [lhs, rhs], + lo_bits, + hi, + expr_rv32_packed_mulhu, + 4 +); + +pub struct Rv32PackedMulhAdapterOracleSparseTime { + core: SparseBitsAndWeightsExprOracle<7, 2>, } -impl Rv32PackedMulhsuAdapterOracleSparseTime { +impl Rv32PackedMulhAdapterOracleSparseTime { pub fn new( r_cycle: &[K], has_lookup: SparseIdxVec, lhs: SparseIdxVec, rhs: SparseIdxVec, lhs_sign: SparseIdxVec, + rhs_sign: SparseIdxVec, hi: SparseIdxVec, - borrow: SparseIdxVec, + k: SparseIdxVec, val: SparseIdxVec, + weights: [K; 2], ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(hi.len(), 1usize << ell_n); - debug_assert_eq!(borrow.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - lhs_sign, - hi, - borrow, - val, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - eq expr: degree 2 (sign·rhs) - // ⇒ total degree ≤ 1 + 1 + 2 = 4 - degree_bound: 4, + core: SparseBitsAndWeightsExprOracle::new( + r_cycle, + has_lookup, + [lhs, rhs, lhs_sign, rhs_sign, hi, k, val], + Vec::new(), + Vec::new(), + weights, + 5, + expr_rv32_packed_mulh_adapter, + ), } } } +impl_round_oracle_via_core!(Rv32PackedMulhAdapterOracleSparseTime); -impl RoundOracle for Rv32PackedMulhsuAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); +define_sparse_time_expr_oracle!( + Rv32PackedMulhsuAdapterOracleSparseTime, + 6, + 4, + expr_rv32_packed_mulhsu_adapter, + [lhs, rhs, lhs_sign, hi, borrow, val] +); - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let rhs = self.rhs.singleton_value(); - let lhs_sign = self.lhs_sign.singleton_value(); - let hi = self.hi.singleton_value(); - let borrow = self.borrow.singleton_value(); - let val = self.val.singleton_value(); - let expr = hi - lhs_sign * rhs - val + borrow * two32; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; +macro_rules! define_val_prod_bits_oracle { + ($name:ident, $expr:expr) => { + pub struct $name { + core: SparseValProdBitsOracle, } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + diff_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + debug_assert_eq!(diff_bits.len(), 32); + Self { + core: SparseValProdBitsOracle::new(r_cycle, has_lookup, val, diff_bits, 34, $expr), + } } + } + impl_round_oracle_via_core!($name); + }; +} - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let lhs_sign0 = self.lhs_sign.get(child0); - let lhs_sign1 = self.lhs_sign.get(child1); - let hi0 = self.hi.get(child0); - let hi1 = self.hi.get(child1); - let borrow0 = self.borrow.get(child0); - let borrow1 = self.borrow.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); +define_val_prod_bits_oracle!(Rv32PackedEqOracleSparseTime, expr_eq_from_prod); + +define_weighted_bits32_oracle3!( + Rv32PackedEqAdapterOracleSparseTime, + [lhs, rhs, borrow], + expr_rv32_packed_eq_adapter, + 3 +); + +define_val_prod_bits_oracle!(Rv32PackedNeqOracleSparseTime, expr_neq_from_prod); + +define_weighted_bits32_oracle3!( + Rv32PackedNeqAdapterOracleSparseTime, + [lhs, rhs, borrow], + expr_rv32_packed_eq_adapter, + 3 +); + +define_sparse_time_expr_oracle!( + Rv32PackedSltOracleSparseTime, + 6, + 3, + expr_rv32_packed_slt, + [lhs, rhs, lhs_sign, rhs_sign, diff, out] +); + +define_sparse_time_expr_oracle!( + Rv32PackedSltuOracleSparseTime, + 4, + 3, + expr_rv32_packed_sltu, + [lhs, rhs, diff, out] +); + +macro_rules! define_shift_oracle_no_sign { + ($name:ident, $bits_name:ident, $bits_len:expr, $expr:expr) => { + pub struct $name { + core: SparseShiftExprOracle, + } + + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + $bits_name: Vec>, + val: SparseIdxVec, + ) -> Self { + debug_assert_eq!($bits_name.len(), $bits_len); + Self { + core: SparseShiftExprOracle::new( + r_cycle, has_lookup, lhs, shamt_bits, $bits_name, None, val, 8, $expr, + ), + } + } + } + impl_round_oracle_via_core!($name); + }; +} - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } +macro_rules! define_shift_oracle_with_sign { + ($name:ident, $bits_len:expr, $expr:expr) => { + pub struct $name { + core: SparseShiftExprOracle, + } + + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + shamt_bits: Vec>, + sign: SparseIdxVec, + rem_bits: Vec>, + val: SparseIdxVec, + ) -> Self { + debug_assert_eq!(rem_bits.len(), $bits_len); + Self { + core: SparseShiftExprOracle::new( + r_cycle, + has_lookup, + lhs, + shamt_bits, + rem_bits, + Some(sign), + val, + 8, + $expr, + ), + } + } + } + impl_round_oracle_via_core!($name); + }; +} - let rhs_x = interp(rhs0, rhs1, x); - let lhs_sign_x = interp(lhs_sign0, lhs_sign1, x); - let hi_x = interp(hi0, hi1, x); - let borrow_x = interp(borrow0, borrow1, x); - let val_x = interp(val0, val1, x); +macro_rules! define_shift_rem_bound_oracle { + ($name:ident, $bits_len:expr) => { + pub struct $name { + core: SparseShiftRemBoundOracle, + } - let expr_x = hi_x - lhs_sign_x * rhs_x - val_x + borrow_x * two32; - if expr_x == K::ZERO { - continue; + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + shamt_bits: Vec>, + rem_bits: Vec>, + ) -> Self { + debug_assert_eq!(rem_bits.len(), $bits_len); + Self { + core: SparseShiftRemBoundOracle::new(r_cycle, has_lookup, shamt_bits, rem_bits), } - - ys[i] += chi_x * gate_x * expr_x; } } + impl_round_oracle_via_core!($name); + }; +} - ys - } +define_shift_oracle_no_sign!(Rv32PackedSllOracleSparseTime, carry_bits, 32, expr_rv32_packed_sll); - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } +define_shift_oracle_no_sign!(Rv32PackedSrlOracleSparseTime, rem_bits, 32, expr_rv32_packed_srl); - fn degree_bound(&self) -> usize { - self.degree_bound - } +define_shift_rem_bound_oracle!(Rv32PackedSrlAdapterOracleSparseTime, 32); - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.lhs_sign.fold_round_in_place(r); - self.hi.fold_round_in_place(r); - self.borrow.fold_round_in_place(r); - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} +define_shift_oracle_with_sign!(Rv32PackedSraOracleSparseTime, 31, expr_rv32_packed_sra); -/// Sparse Route A oracle for RV32 packed EQ correctness (Ajtai-representable): -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (val(t) - Π_i (1 - diff_bit_i(t))) -/// -/// Where `diff_bits` encode `diff = lhs - rhs (mod 2^32)` (checked by the adapter oracle). -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedEqOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - diff_bits: Vec>, - val: SparseIdxVec, - degree_bound: usize, -} +define_shift_rem_bound_oracle!(Rv32PackedSraAdapterOracleSparseTime, 31); -impl Rv32PackedEqOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - diff_bits: Vec>, - val: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(diff_bits.len(), 32); - for (i, b) in diff_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); - } +define_sparse_time_expr_oracle!( + Rv32PackedDivuOracleSparseTime, + 5, + 5, + expr_rv32_packed_divu, + [lhs, rhs, rem, rhs_is_zero, quot] +); - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - diff_bits, - val, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - Π_i (1 - diff_bit_i(t)): product of 32 linear terms (degree 32) - // ⇒ total degree ≤ 1 + 1 + 32 = 34 - degree_bound: 34, +define_sparse_time_expr_oracle!( + Rv32PackedRemuOracleSparseTime, + 5, + 5, + expr_rv32_packed_remu, + [lhs, rhs, quot, rhs_is_zero, rem] +); + +macro_rules! define_bits_and_weights32_oracle { + ($name:ident, $n:expr, $m:expr, [$($col:ident),+ $(,)?], $degree:expr, $expr:expr) => { + pub struct $name { + core: SparseBitsAndWeightsExprOracle<$n, $m>, } - } -} -impl RoundOracle for Rv32PackedEqOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let val = self.val.singleton_value(); - let mut prod = K::ONE; - for b in self.diff_bits.iter() { - let bit = b.singleton_value(); - prod *= K::ONE - bit; + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + $($col: SparseIdxVec,)+ + diff_bits: Vec>, + weights: [K; $m], + ) -> Self { + debug_assert_eq!(diff_bits.len(), 32); + let bit_weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); + Self { + core: SparseBitsAndWeightsExprOracle::new( + r_cycle, + has_lookup, + [$($col),+], + diff_bits, + bit_weights, + weights, + $degree, + $expr, + ), + } } - let expr = val - prod; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; } + impl_round_oracle_via_core!($name); + }; +} - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); +define_bits_and_weights32_oracle!( + Rv32PackedDivRemuAdapterOracleSparseTime, + 4, + 4, + [rhs, rhs_is_zero, rem, diff], + 4, + expr_rv32_packed_divremu_adapter +); + +define_sparse_time_expr_oracle!( + Rv32PackedDivOracleSparseTime, + 6, + 7, + expr_rv32_packed_div, + [lhs_sign, rhs_sign, rhs_is_zero, q_abs, q_is_zero, val] +); + +define_sparse_time_expr_oracle!( + Rv32PackedRemOracleSparseTime, + 6, + 7, + expr_rv32_packed_rem, + [lhs, lhs_sign, rhs_is_zero, r_abs, r_is_zero, val] +); + +define_bits_and_weights32_oracle!( + Rv32PackedDivRemAdapterOracleSparseTime, + 10, + 7, + [ + lhs, + rhs, + rhs_is_zero, + lhs_sign, + rhs_sign, + q_abs, + r_abs, + mag, + mag_is_zero, + diff + ], + 6, + expr_rv32_packed_divrem_adapter +); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Rv32PackedBitwiseOp2 { + And, + Andn, + Or, + Xor, +} - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let val_x = interp(val0, val1, x); - let mut prod_x = K::ONE; - for b in self.diff_bits.iter() { - let b0 = b.get(child0); - let b1 = b.get(child1); - let bit_x = interp(b0, b1, x); - prod_x *= K::ONE - bit_x; - } - let expr_x = val_x - prod_x; - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - for b in self.diff_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed EQ/NEQ diff decomposition check: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - rhs(t) - diff(t) + borrow(t)·2^32) -/// -/// Where: -/// - `borrow(t)` is the SUB borrow bit (1 iff lhs < rhs), -/// - `diff(t) = Σ_i 2^i · diff_bit_i(t)` is the u32 wraparound difference `lhs - rhs (mod 2^32)`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedEqAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - borrow: SparseIdxVec, - diff_bits: Vec>, - degree_bound: usize, -} - -impl Rv32PackedEqAdapterOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - borrow: SparseIdxVec, - diff_bits: Vec>, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(borrow.len(), 1usize << ell_n); - debug_assert_eq!(diff_bits.len(), 32); - for (i, b) in diff_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - borrow, - diff_bits, - degree_bound: 3, - } - } -} - -impl RoundOracle for Rv32PackedEqAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let two32 = K::from_u64(1u64 << 32); - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let borrow = self.borrow.singleton_value(); - let mut diff = K::ZERO; - for (i, b) in self.diff_bits.iter().enumerate() { - let bit = b.singleton_value(); - diff += bit * K::from_u64(1u64 << i); - } - let expr = lhs - rhs - diff + borrow * two32; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let two32 = K::from_u64(1u64 << 32); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let borrow0 = self.borrow.get(child0); - let borrow1 = self.borrow.get(child1); - let mut diff0 = K::ZERO; - let mut diff1 = K::ZERO; - for (i, b) in self.diff_bits.iter().enumerate() { - let pow = K::from_u64(1u64 << i); - diff0 += b.get(child0) * pow; - diff1 += b.get(child1) * pow; - } - let expr0 = lhs0 - rhs0 - diff0 + borrow0 * two32; - let expr1 = lhs1 - rhs1 - diff1 + borrow1 * two32; - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let expr_x = interp(expr0, expr1, x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.borrow.fold_round_in_place(r); - for b in self.diff_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed NEQ correctness (Ajtai-representable): -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (val(t) - (1 - Π_i (1 - diff_bit_i(t)))) -/// -/// Where `diff_bits` encode `diff = lhs - rhs (mod 2^32)` (checked by the adapter oracle). -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedNeqOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - diff_bits: Vec>, - val: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedNeqOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - diff_bits: Vec>, - val: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(diff_bits.len(), 32); - for (i, b) in diff_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - diff_bits, - val, - // Same degree bound as EQ: 1 + 1 + 32 = 34. - degree_bound: 34, - } - } -} - -impl RoundOracle for Rv32PackedNeqOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let val = self.val.singleton_value(); - let mut prod = K::ONE; - for b in self.diff_bits.iter() { - let bit = b.singleton_value(); - prod *= K::ONE - bit; - } - let expr = val + prod - K::ONE; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let val_x = interp(val0, val1, x); - let mut prod_x = K::ONE; - for b in self.diff_bits.iter() { - let b0 = b.get(child0); - let b1 = b.get(child1); - let bit_x = interp(b0, b1, x); - prod_x *= K::ONE - bit_x; - } - let expr_x = val_x + prod_x - K::ONE; - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - for b in self.diff_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed NEQ diff decomposition check: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - rhs(t) - diff(t) + borrow(t)·2^32) -/// -/// See [`Rv32PackedEqAdapterOracleSparseTime`] for details; this is identical but kept as a -/// distinct type for call-site clarity. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedNeqAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - borrow: SparseIdxVec, - diff_bits: Vec>, - degree_bound: usize, -} - -impl Rv32PackedNeqAdapterOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - borrow: SparseIdxVec, - diff_bits: Vec>, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(borrow.len(), 1usize << ell_n); - debug_assert_eq!(diff_bits.len(), 32); - for (i, b) in diff_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - borrow, - diff_bits, - degree_bound: 3, - } - } -} - -impl RoundOracle for Rv32PackedNeqAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let two32 = K::from_u64(1u64 << 32); - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let borrow = self.borrow.singleton_value(); - let mut diff = K::ZERO; - for (i, b) in self.diff_bits.iter().enumerate() { - let bit = b.singleton_value(); - diff += bit * K::from_u64(1u64 << i); - } - let expr = lhs - rhs - diff + borrow * two32; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let two32 = K::from_u64(1u64 << 32); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let borrow0 = self.borrow.get(child0); - let borrow1 = self.borrow.get(child1); - let mut diff0 = K::ZERO; - let mut diff1 = K::ZERO; - for (i, b) in self.diff_bits.iter().enumerate() { - let pow = K::from_u64(1u64 << i); - diff0 += b.get(child0) * pow; - diff1 += b.get(child1) * pow; - } - let expr0 = lhs0 - rhs0 - diff0 + borrow0 * two32; - let expr1 = lhs1 - rhs1 - diff1 + borrow1 * two32; - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let expr_x = interp(expr0, expr1, x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.borrow.fold_round_in_place(r); - for b in self.diff_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -// ============================================================================ -// Packed-key RV32 SLTU Shout (time-domain) -// ============================================================================ - -/// Sparse Route A oracle for RV32 packed SLT (signed less-than) correctness. -/// -/// We reduce signed comparison to unsigned comparison by XOR-biasing both operands with `2^31` -/// (flip the sign bit). Let: -/// lhs_b = lhs ⊕ 2^31, rhs_b = rhs ⊕ 2^31 -/// then `(lhs as i32) < (rhs as i32)` iff `lhs_b < rhs_b` as unsigned. -/// -/// With witness bits `lhs_sign`, `rhs_sign` (intended as the msb of lhs/rhs), we implement the -/// XOR-biasing arithmetically: -/// x ⊕ 2^31 = x + (1 - 2·msb(x))·2^31. -/// -/// The correctness constraint is: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs_b(t) - rhs_b(t) - diff(t) + out(t)·2^32) -/// -/// Where `out(t)` is the SLT result bit (1 iff lhs < rhs in signed order, else 0) and `diff(t)` -/// is the u32 difference `lhs_b - rhs_b (mod 2^32)`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSltOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - diff: SparseIdxVec, - out: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedSltOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - diff: SparseIdxVec, - out: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(diff.len(), 1usize << ell_n); - debug_assert_eq!(out.len(), 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - lhs_sign, - rhs_sign, - diff, - out, - degree_bound: 3, - } - } -} - -impl RoundOracle for Rv32PackedSltOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two31 = K::from_u64(1u64 << 31); - let two32 = K::from_u64(1u64 << 32); - let two = K::from_u64(2); - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let lhs_sign = self.lhs_sign.singleton_value(); - let rhs_sign = self.rhs_sign.singleton_value(); - let diff = self.diff.singleton_value(); - let out = self.out.singleton_value(); - - let lhs_b = lhs + (K::ONE - two * lhs_sign) * two31; - let rhs_b = rhs + (K::ONE - two * rhs_sign) * two31; - let expr = lhs_b - rhs_b - diff + out * two32; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let lhs_sign0 = self.lhs_sign.get(child0); - let lhs_sign1 = self.lhs_sign.get(child1); - let rhs_sign0 = self.rhs_sign.get(child0); - let rhs_sign1 = self.rhs_sign.get(child1); - let diff0 = self.diff.get(child0); - let diff1 = self.diff.get(child1); - let out0 = self.out.get(child0); - let out1 = self.out.get(child1); - - let lhs_b0 = lhs0 + (K::ONE - two * lhs_sign0) * two31; - let lhs_b1 = lhs1 + (K::ONE - two * lhs_sign1) * two31; - let rhs_b0 = rhs0 + (K::ONE - two * rhs_sign0) * two31; - let rhs_b1 = rhs1 + (K::ONE - two * rhs_sign1) * two31; - - let expr0 = lhs_b0 - rhs_b0 - diff0 + out0 * two32; - let expr1 = lhs_b1 - rhs_b1 - diff1 + out1 * two32; - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let expr_x = interp(expr0, expr1, x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.lhs_sign.fold_round_in_place(r); - self.rhs_sign.fold_round_in_place(r); - self.diff.fold_round_in_place(r); - self.out.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed SLTU correctness: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - rhs(t) - diff(t) + out(t)·2^32) -/// -/// Where `out(t)` is the SLTU result bit (1 iff lhs < rhs, else 0) and `diff(t)` is the u32 -/// difference `lhs - rhs (mod 2^32)`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSltuOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - diff: SparseIdxVec, - out: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedSltuOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - diff: SparseIdxVec, - out: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(diff.len(), 1usize << ell_n); - debug_assert_eq!(out.len(), 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - diff, - out, - degree_bound: 3, - } - } -} - -impl RoundOracle for Rv32PackedSltuOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let two32 = K::from_u64(1u64 << 32); - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let diff = self.diff.singleton_value(); - let out = self.out.singleton_value(); - let expr = lhs - rhs - diff + out * two32; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let two32 = K::from_u64(1u64 << 32); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let diff0 = self.diff.get(child0); - let diff1 = self.diff.get(child1); - let out0 = self.out.get(child0); - let out1 = self.out.get(child1); - - let expr0 = lhs0 - rhs0 - diff0 + out0 * two32; - let expr1 = lhs1 - rhs1 - diff1 + out1 * two32; - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let expr_x = interp(expr0, expr1, x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.diff.fold_round_in_place(r); - self.out.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -// ============================================================================ -// Packed-key RV32 SLL Shout (time-domain) -// ============================================================================ - -/// Sparse Route A oracle for RV32 packed SLL correctness: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) · 2^{shamt(t)} - val(t) - carry(t)·2^32) -/// -/// Where: -/// - `shamt(t)` is the shift amount (0..31), encoded as 5 Boolean columns, -/// - `carry(t)` is the high part of `(lhs << shamt)` as a u32 (range-checked separately), -/// - `val(t)` is the low 32 bits of `(lhs << shamt)`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSllOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - shamt_bits: Vec>, - carry_bits: Vec>, - val: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedSllOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - shamt_bits: Vec>, - carry_bits: Vec>, - val: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(shamt_bits.len(), 5); - for (i, b) in shamt_bits.iter().enumerate() { - debug_assert_eq!( - b.len(), - 1usize << ell_n, - "shamt_bits[{i}] length must match time domain" - ); - } - debug_assert_eq!(carry_bits.len(), 32); - for (i, b) in carry_bits.iter().enumerate() { - debug_assert_eq!( - b.len(), - 1usize << ell_n, - "carry_bits[{i}] length must match time domain" - ); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - shamt_bits, - carry_bits, - val, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - lhs(t): multilinear (degree 1) - // - 2^{shamt(t)}: product of 5 linear terms in the shamt bits (degree 5) - // - val(t), carry(t): multilinear (degree 1) - // ⇒ total degree ≤ 1 + 1 + max(1+5, 1, 1) = 8 - degree_bound: 8, - } - } -} - -impl RoundOracle for Rv32PackedSllOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - let pow2_const: [K; 5] = [ - K::from_u64(2), - K::from_u64(4), - K::from_u64(16), - K::from_u64(256), - K::from_u64(65536), - ]; - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let val = self.val.singleton_value(); - - let mut pow2 = K::ONE; - for (b, c) in self.shamt_bits.iter().zip(pow2_const.iter()) { - let bit = b.singleton_value(); - pow2 *= K::ONE + bit * (*c - K::ONE); - } - - let mut carry = K::ZERO; - for (i, b) in self.carry_bits.iter().enumerate() { - carry += b.singleton_value() * K::from_u64(1u64 << i); - } - - let expr = lhs * pow2 - val - carry * two32; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut b0s: [K; 5] = [K::ZERO; 5]; - let mut b1s: [K; 5] = [K::ZERO; 5]; - for (i, b) in self.shamt_bits.iter().enumerate() { - b0s[i] = b.get(child0); - b1s[i] = b.get(child1); - } - let mut c0s: [K; 32] = [K::ZERO; 32]; - let mut c1s: [K; 32] = [K::ZERO; 32]; - for (i, b) in self.carry_bits.iter().enumerate() { - c0s[i] = b.get(child0); - c1s[i] = b.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs_x = interp(lhs0, lhs1, x); - let val_x = interp(val0, val1, x); - - let mut pow2_x = K::ONE; - for j in 0..5 { - let b_x = interp(b0s[j], b1s[j], x); - pow2_x *= K::ONE + b_x * (pow2_const[j] - K::ONE); - } - - let mut carry_x = K::ZERO; - for j in 0..32 { - let c_x = interp(c0s[j], c1s[j], x); - carry_x += c_x * K::from_u64(1u64 << j); - } - - let expr_x = lhs_x * pow2_x - val_x - carry_x * two32; - if expr_x == K::ZERO { - continue; - } - - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - for b in self.shamt_bits.iter_mut() { - b.fold_round_in_place(r); - } - for b in self.carry_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -// ============================================================================ -// Packed-key RV32 SRL Shout (time-domain) -// ============================================================================ - -/// Sparse Route A oracle for RV32 packed SRL correctness: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - val(t)·2^{shamt(t)} - rem(t)) -/// -/// Where: -/// - `shamt(t)` is the shift amount (0..31), encoded as 5 Boolean columns, -/// - `rem(t)` is the remainder `lhs mod 2^{shamt}`, encoded as 32 Boolean columns, -/// - `val(t)` is the SRL result (`lhs >> shamt`). -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSrlOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - shamt_bits: Vec>, - rem_bits: Vec>, - val: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedSrlOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - shamt_bits: Vec>, - rem_bits: Vec>, - val: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(shamt_bits.len(), 5); - for (i, b) in shamt_bits.iter().enumerate() { - debug_assert_eq!( - b.len(), - 1usize << ell_n, - "shamt_bits[{i}] length must match time domain" - ); - } - debug_assert_eq!(rem_bits.len(), 32); - for (i, b) in rem_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - shamt_bits, - rem_bits, - val, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - lhs(t): multilinear (degree 1) - // - 2^{shamt(t)}: product of 5 linear terms in the shamt bits (degree 5) - // - val(t), rem(t): multilinear (degree 1) - // ⇒ total degree ≤ 1 + 1 + max(1+5, 1, 1) = 8 - degree_bound: 8, - } - } -} - -impl RoundOracle for Rv32PackedSrlOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let pow2_const: [K; 5] = [ - K::from_u64(2), - K::from_u64(4), - K::from_u64(16), - K::from_u64(256), - K::from_u64(65536), - ]; - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let val = self.val.singleton_value(); - - let mut pow2 = K::ONE; - for (b, c) in self.shamt_bits.iter().zip(pow2_const.iter()) { - let bit = b.singleton_value(); - pow2 *= K::ONE + bit * (*c - K::ONE); - } - - let mut rem = K::ZERO; - for (i, b) in self.rem_bits.iter().enumerate() { - rem += b.singleton_value() * K::from_u64(1u64 << i); - } - - let expr = lhs - val * pow2 - rem; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut b0s: [K; 5] = [K::ZERO; 5]; - let mut b1s: [K; 5] = [K::ZERO; 5]; - for (i, b) in self.shamt_bits.iter().enumerate() { - b0s[i] = b.get(child0); - b1s[i] = b.get(child1); - } - let mut r0s: [K; 32] = [K::ZERO; 32]; - let mut r1s: [K; 32] = [K::ZERO; 32]; - for (i, b) in self.rem_bits.iter().enumerate() { - r0s[i] = b.get(child0); - r1s[i] = b.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs_x = interp(lhs0, lhs1, x); - let val_x = interp(val0, val1, x); - - let mut pow2_x = K::ONE; - for j in 0..5 { - let b_x = interp(b0s[j], b1s[j], x); - pow2_x *= K::ONE + b_x * (pow2_const[j] - K::ONE); - } - - let mut rem_x = K::ZERO; - for j in 0..32 { - let r_x = interp(r0s[j], r1s[j], x); - rem_x += r_x * K::from_u64(1u64 << j); - } - - let expr_x = lhs_x - val_x * pow2_x - rem_x; - if expr_x == K::ZERO { - continue; - } - - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - for b in self.shamt_bits.iter_mut() { - b.fold_round_in_place(r); - } - for b in self.rem_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle that enforces the SRL remainder is < 2^{shamt}: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · Σ_s eq(shamt(t), s) · Σ_{i≥s} 2^i · rem_bit_i(t) -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSrlAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - shamt_bits: Vec>, - rem_bits: Vec>, - degree_bound: usize, -} - -impl Rv32PackedSrlAdapterOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - shamt_bits: Vec>, - rem_bits: Vec>, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(shamt_bits.len(), 5); - for (i, b) in shamt_bits.iter().enumerate() { - debug_assert_eq!( - b.len(), - 1usize << ell_n, - "shamt_bits[{i}] length must match time domain" - ); - } - debug_assert_eq!(rem_bits.len(), 32); - for (i, b) in rem_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - shamt_bits, - rem_bits, - // Degree bound: chi (1) + gate (1) + eq_s(shamt) (5) + tail(rem) (1) = 8. - degree_bound: 8, - } - } -} - -impl RoundOracle for Rv32PackedSrlAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - - let mut shamt: [K; 5] = [K::ZERO; 5]; - for (i, b) in self.shamt_bits.iter().enumerate() { - shamt[i] = b.singleton_value(); - } - - let mut rem: [K; 32] = [K::ZERO; 32]; - for (i, b) in self.rem_bits.iter().enumerate() { - rem[i] = b.singleton_value(); - } - - // tail_sum[s] = Σ_{i≥s} 2^i · rem_i - let mut tail_sum: [K; 32] = [K::ZERO; 32]; - let mut tail = K::ZERO; - for i in (0..32).rev() { - tail += rem[i] * K::from_u64(1u64 << i); - tail_sum[i] = tail; - } - - // eq_s(shamt) for s in 0..32 - let mut eq: [K; 32] = [K::ZERO; 32]; - for s in 0..32usize { - let mut prod = K::ONE; - for j in 0..5usize { - let b = shamt[j]; - if ((s >> j) & 1) == 1 { - prod *= b; - } else { - prod *= K::ONE - b; - } - } - eq[s] = prod; - } - - let mut expr = K::ZERO; - for s in 0..32usize { - expr += eq[s] * tail_sum[s]; - } - - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut b0s: [K; 5] = [K::ZERO; 5]; - let mut b1s: [K; 5] = [K::ZERO; 5]; - for (i, b) in self.shamt_bits.iter().enumerate() { - b0s[i] = b.get(child0); - b1s[i] = b.get(child1); - } - let mut r0s: [K; 32] = [K::ZERO; 32]; - let mut r1s: [K; 32] = [K::ZERO; 32]; - for (i, b) in self.rem_bits.iter().enumerate() { - r0s[i] = b.get(child0); - r1s[i] = b.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let mut shamt: [K; 5] = [K::ZERO; 5]; - for j in 0..5usize { - shamt[j] = interp(b0s[j], b1s[j], x); - } - - let mut rem: [K; 32] = [K::ZERO; 32]; - for j in 0..32usize { - rem[j] = interp(r0s[j], r1s[j], x); - } - - // tail_sum[s] = Σ_{i≥s} 2^i · rem_i - let mut tail_sum: [K; 32] = [K::ZERO; 32]; - let mut tail = K::ZERO; - for j in (0..32).rev() { - tail += rem[j] * K::from_u64(1u64 << j); - tail_sum[j] = tail; - } - - // eq_s(shamt) for s in 0..32 - let mut expr_x = K::ZERO; - for s in 0..32usize { - let mut prod = K::ONE; - for j in 0..5usize { - let b = shamt[j]; - if ((s >> j) & 1) == 1 { - prod *= b; - } else { - prod *= K::ONE - b; - } - } - expr_x += prod * tail_sum[s]; - } - - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - for b in self.shamt_bits.iter_mut() { - b.fold_round_in_place(r); - } - for b in self.rem_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -// ============================================================================ -// Packed-key RV32 SRA Shout (time-domain) -// ============================================================================ - -/// Sparse Route A oracle for RV32 packed SRA correctness: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (lhs(t) - val(t)·2^{shamt(t)} - rem(t) - sign(t)·2^32·(1 - 2^{shamt(t)})) -/// -/// Where: -/// - `shamt(t)` is the shift amount (0..31), encoded as 5 Boolean columns, -/// - `sign(t)` is the sign bit of `lhs` (and the expected sign bit of `val`), encoded as 1 Boolean column, -/// - `rem(t)` is the remainder in the signed floor-division identity: -/// (lhs_signed) = (val_signed)·2^{shamt} + rem, with rem ∈ [0, 2^{shamt}) -/// encoded as 31 Boolean columns (bits 0..30), -/// - `val(t)` is the SRA result (`(lhs as i32) >> shamt`, represented as u32 in the field). -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSraOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - shamt_bits: Vec>, - sign: SparseIdxVec, - rem_bits: Vec>, - val: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedSraOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - shamt_bits: Vec>, - sign: SparseIdxVec, - rem_bits: Vec>, - val: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(sign.len(), 1usize << ell_n); - debug_assert_eq!(shamt_bits.len(), 5); - for (i, b) in shamt_bits.iter().enumerate() { - debug_assert_eq!( - b.len(), - 1usize << ell_n, - "shamt_bits[{i}] length must match time domain" - ); - } - debug_assert_eq!(rem_bits.len(), 31); - for (i, b) in rem_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - shamt_bits, - sign, - rem_bits, - val, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - lhs(t): multilinear (degree 1) - // - 2^{shamt(t)}: product of 5 linear terms in the shamt bits (degree 5) - // - val(t), rem(t), sign(t): multilinear (degree 1) - // ⇒ total degree ≤ 1 + 1 + max(1+5, 1, 1+5, 1) = 8 - degree_bound: 8, - } - } -} - -impl RoundOracle for Rv32PackedSraOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - let pow2_const: [K; 5] = [ - K::from_u64(2), - K::from_u64(4), - K::from_u64(16), - K::from_u64(256), - K::from_u64(65536), - ]; - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let val = self.val.singleton_value(); - let sign = self.sign.singleton_value(); - - let mut pow2 = K::ONE; - for (b, c) in self.shamt_bits.iter().zip(pow2_const.iter()) { - let bit = b.singleton_value(); - pow2 *= K::ONE + bit * (*c - K::ONE); - } - - let mut rem = K::ZERO; - for (i, b) in self.rem_bits.iter().enumerate() { - rem += b.singleton_value() * K::from_u64(1u64 << i); - } - - let corr = sign * two32 * (K::ONE - pow2); - let expr = lhs - val * pow2 - rem - corr; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - let sign0 = self.sign.get(child0); - let sign1 = self.sign.get(child1); - - // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut b0s: [K; 5] = [K::ZERO; 5]; - let mut b1s: [K; 5] = [K::ZERO; 5]; - for (i, b) in self.shamt_bits.iter().enumerate() { - b0s[i] = b.get(child0); - b1s[i] = b.get(child1); - } - let mut r0s: [K; 31] = [K::ZERO; 31]; - let mut r1s: [K; 31] = [K::ZERO; 31]; - for (i, b) in self.rem_bits.iter().enumerate() { - r0s[i] = b.get(child0); - r1s[i] = b.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs_x = interp(lhs0, lhs1, x); - let val_x = interp(val0, val1, x); - let sign_x = interp(sign0, sign1, x); - - let mut pow2_x = K::ONE; - for j in 0..5 { - let b_x = interp(b0s[j], b1s[j], x); - pow2_x *= K::ONE + b_x * (pow2_const[j] - K::ONE); - } - - let mut rem_x = K::ZERO; - for j in 0..31 { - let r_x = interp(r0s[j], r1s[j], x); - rem_x += r_x * K::from_u64(1u64 << j); - } - - let corr_x = sign_x * two32 * (K::ONE - pow2_x); - let expr_x = lhs_x - val_x * pow2_x - rem_x - corr_x; - if expr_x == K::ZERO { - continue; - } - - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - for b in self.shamt_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.sign.fold_round_in_place(r); - for b in self.rem_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle that enforces the SRA remainder is < 2^{shamt}: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · Σ_s eq(shamt(t), s) · Σ_{i≥s} 2^i · rem_bit_i(t) -/// -/// For SRA, we only carry 31 remainder bits (0..30). This is sufficient because `shamt ∈ [0,31]` -/// and the remainder range is always `< 2^{shamt} ≤ 2^{31}`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedSraAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - shamt_bits: Vec>, - rem_bits: Vec>, - degree_bound: usize, -} - -impl Rv32PackedSraAdapterOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - shamt_bits: Vec>, - rem_bits: Vec>, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(shamt_bits.len(), 5); - for (i, b) in shamt_bits.iter().enumerate() { - debug_assert_eq!( - b.len(), - 1usize << ell_n, - "shamt_bits[{i}] length must match time domain" - ); - } - debug_assert_eq!(rem_bits.len(), 31); - for (i, b) in rem_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "rem_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - shamt_bits, - rem_bits, - // Degree bound: chi (1) + gate (1) + eq_s(shamt) (5) + tail(rem) (1) = 8. - degree_bound: 8, - } - } -} - -impl RoundOracle for Rv32PackedSraAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - - let mut shamt: [K; 5] = [K::ZERO; 5]; - for (i, b) in self.shamt_bits.iter().enumerate() { - shamt[i] = b.singleton_value(); - } - - let mut rem: [K; 31] = [K::ZERO; 31]; - for (i, b) in self.rem_bits.iter().enumerate() { - rem[i] = b.singleton_value(); - } - - // tail_sum[s] = Σ_{i≥s} 2^i · rem_i, with tail_sum[31]=0 (no bits >= 31). - let mut tail_sum: [K; 32] = [K::ZERO; 32]; - let mut tail = K::ZERO; - for i in (0..31).rev() { - tail += rem[i] * K::from_u64(1u64 << i); - tail_sum[i] = tail; - } - tail_sum[31] = K::ZERO; - - let mut expr = K::ZERO; - for s in 0..32usize { - let mut prod = K::ONE; - for j in 0..5usize { - let b = shamt[j]; - if ((s >> j) & 1) == 1 { - prod *= b; - } else { - prod *= K::ONE - b; - } - } - expr += prod * tail_sum[s]; - } - - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - // Pre-fetch bit endpoints for this pair to avoid repeated sparse lookups per eval point. - let mut b0s: [K; 5] = [K::ZERO; 5]; - let mut b1s: [K; 5] = [K::ZERO; 5]; - for (i, b) in self.shamt_bits.iter().enumerate() { - b0s[i] = b.get(child0); - b1s[i] = b.get(child1); - } - let mut r0s: [K; 31] = [K::ZERO; 31]; - let mut r1s: [K; 31] = [K::ZERO; 31]; - for (i, b) in self.rem_bits.iter().enumerate() { - r0s[i] = b.get(child0); - r1s[i] = b.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let mut shamt: [K; 5] = [K::ZERO; 5]; - for j in 0..5usize { - shamt[j] = interp(b0s[j], b1s[j], x); - } - - let mut rem: [K; 31] = [K::ZERO; 31]; - for j in 0..31usize { - rem[j] = interp(r0s[j], r1s[j], x); - } - - // tail_sum[s] = Σ_{i≥s} 2^i · rem_i, with tail_sum[31]=0 (no bits >= 31). - let mut tail_sum: [K; 32] = [K::ZERO; 32]; - let mut tail = K::ZERO; - for j in (0..31).rev() { - tail += rem[j] * K::from_u64(1u64 << j); - tail_sum[j] = tail; - } - tail_sum[31] = K::ZERO; - - let mut expr_x = K::ZERO; - for s in 0..32usize { - let mut prod = K::ONE; - for j in 0..5usize { - let b = shamt[j]; - if ((s >> j) & 1) == 1 { - prod *= b; - } else { - prod *= K::ONE - b; - } - } - expr_x += prod * tail_sum[s]; - } - - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - for b in self.shamt_bits.iter_mut() { - b.fold_round_in_place(r); - } - for b in self.rem_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -// ============================================================================ -// Packed-key RV32 DIV*/REM* Shout (time-domain) -// ============================================================================ - -/// Sparse Route A oracle for RV32 packed DIVU correctness: -/// Σ_t χ(t)·has_lookup(t)·( z(t)·(quot(t) - 0xFFFF_FFFF) + (1 - z(t))·(lhs(t) - rhs(t)·quot(t) - rem(t)) ) -/// -/// Where: -/// - `quot(t)` is the DIVU output (lane.val), -/// - `rem(t)` is an auxiliary witness column, -/// - `z(t)` is a witness bit intended to be 1 iff `rhs(t) == 0`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedDivuOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - rem: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - quot: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedDivuOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - rem: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - quot: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(rem.len(), 1usize << ell_n); - debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(quot.len(), 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - rem, - rhs_is_zero, - quot, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - rhs(t)·quot(t): degree 2 - // - gated by (1 - z(t)): adds 1 - // ⇒ total degree ≤ 1 + 1 + 2 + 1 = 5 - degree_bound: 5, - } - } -} - -impl RoundOracle for Rv32PackedDivuOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let all_ones = K::from_u64(u32::MAX as u64); - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let rem = self.rem.singleton_value(); - let z = self.rhs_is_zero.singleton_value(); - let quot = self.quot.singleton_value(); - - let expr = z * (quot - all_ones) + (K::ONE - z) * (lhs - rhs * quot - rem); - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let rem0 = self.rem.get(child0); - let rem1 = self.rem.get(child1); - let z0 = self.rhs_is_zero.get(child0); - let z1 = self.rhs_is_zero.get(child1); - let quot0 = self.quot.get(child0); - let quot1 = self.quot.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs_x = interp(lhs0, lhs1, x); - let rhs_x = interp(rhs0, rhs1, x); - let rem_x = interp(rem0, rem1, x); - let z_x = interp(z0, z1, x); - let quot_x = interp(quot0, quot1, x); - - let expr_x = z_x * (quot_x - all_ones) + (K::ONE - z_x) * (lhs_x - rhs_x * quot_x - rem_x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.rem.fold_round_in_place(r); - self.rhs_is_zero.fold_round_in_place(r); - self.quot.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed REMU correctness: -/// Σ_t χ(t)·has_lookup(t)·( z(t)·(rem(t) - lhs(t)) + (1 - z(t))·(lhs(t) - rhs(t)·quot(t) - rem(t)) ) -/// -/// Where: -/// - `rem(t)` is the REMU output (lane.val), -/// - `quot(t)` is an auxiliary witness column, -/// - `z(t)` is a witness bit intended to be 1 iff `rhs(t) == 0`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedRemuOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - quot: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - rem: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedRemuOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - quot: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - rem: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(quot.len(), 1usize << ell_n); - debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(rem.len(), 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - quot, - rhs_is_zero, - rem, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - rhs(t)·quot(t): degree 2 - // - gated by (1 - z(t)): adds 1 - // ⇒ total degree ≤ 1 + 1 + 2 + 1 = 5 - degree_bound: 5, - } - } -} - -impl RoundOracle for Rv32PackedRemuOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let quot = self.quot.singleton_value(); - let z = self.rhs_is_zero.singleton_value(); - let rem = self.rem.singleton_value(); - - let expr = z * (rem - lhs) + (K::ONE - z) * (lhs - rhs * quot - rem); - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let quot0 = self.quot.get(child0); - let quot1 = self.quot.get(child1); - let z0 = self.rhs_is_zero.get(child0); - let z1 = self.rhs_is_zero.get(child1); - let rem0 = self.rem.get(child0); - let rem1 = self.rem.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs_x = interp(lhs0, lhs1, x); - let rhs_x = interp(rhs0, rhs1, x); - let quot_x = interp(quot0, quot1, x); - let z_x = interp(z0, z1, x); - let rem_x = interp(rem0, rem1, x); - - let expr_x = z_x * (rem_x - lhs_x) + (K::ONE - z_x) * (lhs_x - rhs_x * quot_x - rem_x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.quot.fold_round_in_place(r); - self.rhs_is_zero.fold_round_in_place(r); - self.rem.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed DIVU/REMU helpers (adapter): -/// Σ_t χ(t)·has_lookup(t)·Σ_i w_i · c_i(t) -/// -/// Constraints: -/// - `c0 = rhs_is_zero·(1 - rhs_is_zero)` (boolean helper; redundant with bitness) -/// - `c1 = rhs_is_zero·rhs` (rhs_is_zero => rhs==0) -/// - `c2 = (1 - rhs_is_zero)·(rem - rhs - diff + 2^32)` (remainder bound) -/// - `c3 = diff - Σ 2^i·diff_bit_i` (u32 decomposition) -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedDivRemuAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - rhs: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - rem: SparseIdxVec, - diff: SparseIdxVec, - diff_bits: Vec>, - weights: [K; 4], - degree_bound: usize, -} - -impl Rv32PackedDivRemuAdapterOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - rhs: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - rem: SparseIdxVec, - diff: SparseIdxVec, - diff_bits: Vec>, - weights: [K; 4], - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(rem.len(), 1usize << ell_n); - debug_assert_eq!(diff.len(), 1usize << ell_n); - debug_assert_eq!(diff_bits.len(), 32); - for (i, b) in diff_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - rhs, - rhs_is_zero, - rem, - diff, - diff_bits, - weights, - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - remainder bound term multiplies by (1 - rhs_is_zero): degree 2 - // ⇒ total degree ≤ 1 + 1 + 2 = 4 - degree_bound: 4, - } - } -} - -impl RoundOracle for Rv32PackedDivRemuAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - let w0 = self.weights[0]; - let w1 = self.weights[1]; - let w2 = self.weights[2]; - let w3 = self.weights[3]; - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let rhs = self.rhs.singleton_value(); - let z = self.rhs_is_zero.singleton_value(); - let rem = self.rem.singleton_value(); - let diff = self.diff.singleton_value(); - - let mut sum = K::ZERO; - for (i, b) in self.diff_bits.iter().enumerate() { - sum += b.singleton_value() * K::from_u64(1u64 << i); - } - - let c0 = z * (K::ONE - z); - let c1 = z * rhs; - let c2 = (K::ONE - z) * (rem - rhs - diff + two32); - let c3 = diff - sum; - - let expr = w0 * c0 + w1 * c1 + w2 * c2 + w3 * c3; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let z0 = self.rhs_is_zero.get(child0); - let z1 = self.rhs_is_zero.get(child1); - let rem0 = self.rem.get(child0); - let rem1 = self.rem.get(child1); - let diff0 = self.diff.get(child0); - let diff1 = self.diff.get(child1); - - let mut b0s: [K; 32] = [K::ZERO; 32]; - let mut b1s: [K; 32] = [K::ZERO; 32]; - for (j, b) in self.diff_bits.iter().enumerate() { - b0s[j] = b.get(child0); - b1s[j] = b.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let rhs_x = interp(rhs0, rhs1, x); - let z_x = interp(z0, z1, x); - let rem_x = interp(rem0, rem1, x); - let diff_x = interp(diff0, diff1, x); - - let mut sum = K::ZERO; - for j in 0..32 { - let b_x = interp(b0s[j], b1s[j], x); - sum += b_x * K::from_u64(1u64 << j); - } - - let c0 = z_x * (K::ONE - z_x); - let c1 = z_x * rhs_x; - let c2 = (K::ONE - z_x) * (rem_x - rhs_x - diff_x + two32); - let c3 = diff_x - sum; - let expr_x = w0 * c0 + w1 * c1 + w2 * c2 + w3 * c3; - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.rhs_is_zero.fold_round_in_place(r); - self.rem.fold_round_in_place(r); - self.diff.fold_round_in_place(r); - for b in self.diff_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed DIV correctness (signed quotient output). -/// -/// Uses auxiliary `q_abs` (unsigned quotient), sign bits, and a `q_is_zero` witness bit to -/// handle the `-0` normalization case. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedDivOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - q_abs: SparseIdxVec, - q_is_zero: SparseIdxVec, - val: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedDivOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - q_abs: SparseIdxVec, - q_is_zero: SparseIdxVec, - val: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(q_abs.len(), 1usize << ell_n); - debug_assert_eq!(q_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs_sign, - rhs_sign, - rhs_is_zero, - q_abs, - q_is_zero, - val, - // This oracle composes abs-quotient sign logic (degree 4) with rhs_is_zero gating. - degree_bound: 7, - } - } -} - -impl RoundOracle for Rv32PackedDivOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two = K::from_u64(2); - let two32 = K::from_u64(1u64 << 32); - let all_ones = K::from_u64(u32::MAX as u64); - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let s1 = self.lhs_sign.singleton_value(); - let s2 = self.rhs_sign.singleton_value(); - let z = self.rhs_is_zero.singleton_value(); - let q_abs = self.q_abs.singleton_value(); - let q0 = self.q_is_zero.singleton_value(); - let val = self.val.singleton_value(); - - let div_sign = s1 + s2 - two * s1 * s2; - let neg_q = (K::ONE - q0) * (two32 - q_abs); - let q_signed = (K::ONE - div_sign) * q_abs + div_sign * neg_q; - - let expr = z * (val - all_ones) + (K::ONE - z) * (val - q_signed); - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let s10 = self.lhs_sign.get(child0); - let s11 = self.lhs_sign.get(child1); - let s20 = self.rhs_sign.get(child0); - let s21 = self.rhs_sign.get(child1); - let z0 = self.rhs_is_zero.get(child0); - let z1 = self.rhs_is_zero.get(child1); - let q0 = self.q_abs.get(child0); - let q1 = self.q_abs.get(child1); - let qz0 = self.q_is_zero.get(child0); - let qz1 = self.q_is_zero.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let s1 = interp(s10, s11, x); - let s2 = interp(s20, s21, x); - let z = interp(z0, z1, x); - let q_abs = interp(q0, q1, x); - let qz = interp(qz0, qz1, x); - let val = interp(val0, val1, x); - - let div_sign = s1 + s2 - two * s1 * s2; - let neg_q = (K::ONE - qz) * (two32 - q_abs); - let q_signed = (K::ONE - div_sign) * q_abs + div_sign * neg_q; - - let expr_x = z * (val - all_ones) + (K::ONE - z) * (val - q_signed); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs_sign.fold_round_in_place(r); - self.rhs_sign.fold_round_in_place(r); - self.rhs_is_zero.fold_round_in_place(r); - self.q_abs.fold_round_in_place(r); - self.q_is_zero.fold_round_in_place(r); - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed REM correctness (signed remainder output). -/// -/// Uses auxiliary `r_abs` (unsigned remainder), the dividend sign bit, and `r_is_zero` to handle `-0`. -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct Rv32PackedRemOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - r_abs: SparseIdxVec, - r_is_zero: SparseIdxVec, - val: SparseIdxVec, - degree_bound: usize, -} - -impl Rv32PackedRemOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - r_abs: SparseIdxVec, - r_is_zero: SparseIdxVec, - val: SparseIdxVec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(r_abs.len(), 1usize << ell_n); - debug_assert_eq!(r_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - lhs_sign, - rhs_is_zero, - r_abs, - r_is_zero, - val, - degree_bound: 7, - } - } -} - -impl RoundOracle for Rv32PackedRemOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two32 = K::from_u64(1u64 << 32); - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let s = self.lhs_sign.singleton_value(); - let z = self.rhs_is_zero.singleton_value(); - let r_abs = self.r_abs.singleton_value(); - let r0 = self.r_is_zero.singleton_value(); - let val = self.val.singleton_value(); - - let neg_r = (K::ONE - r0) * (two32 - r_abs); - let r_signed = (K::ONE - s) * r_abs + s * neg_r; - let expr = z * (val - lhs) + (K::ONE - z) * (val - r_signed); - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let s0 = self.lhs_sign.get(child0); - let s1 = self.lhs_sign.get(child1); - let z0 = self.rhs_is_zero.get(child0); - let z1 = self.rhs_is_zero.get(child1); - let r0_0 = self.r_abs.get(child0); - let r0_1 = self.r_abs.get(child1); - let rz0 = self.r_is_zero.get(child0); - let rz1 = self.r_is_zero.get(child1); - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs_x = interp(lhs0, lhs1, x); - let s_x = interp(s0, s1, x); - let z_x = interp(z0, z1, x); - let r_abs_x = interp(r0_0, r0_1, x); - let rz_x = interp(rz0, rz1, x); - let val_x = interp(val0, val1, x); - - let neg_r = (K::ONE - rz_x) * (two32 - r_abs_x); - let r_signed = (K::ONE - s_x) * r_abs_x + s_x * neg_r; - let expr_x = z_x * (val_x - lhs_x) + (K::ONE - z_x) * (val_x - r_signed); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.lhs_sign.fold_round_in_place(r); - self.rhs_is_zero.fold_round_in_place(r); - self.r_abs.fold_round_in_place(r); - self.r_is_zero.fold_round_in_place(r); - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for RV32 packed DIV/REM helpers (adapter). -/// -/// This enforces: -/// - zero-detection for `rhs`, -/// - zero-detection for a signed-output magnitude (`q_abs` for DIV or `r_abs` for REM) to handle `-0`, -/// - absolute-value division equation over u32 values when `rhs != 0`, -/// - remainder bound `r_abs < |rhs|` when `rhs != 0`, -/// - u32 decomposition of the remainder-bound witness `diff`. -pub struct Rv32PackedDivRemAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - q_abs: SparseIdxVec, - r_abs: SparseIdxVec, - mag: SparseIdxVec, - mag_is_zero: SparseIdxVec, - diff: SparseIdxVec, - diff_bits: Vec>, - weights: [K; 7], - degree_bound: usize, -} - -impl Rv32PackedDivRemAdapterOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - rhs_is_zero: SparseIdxVec, - lhs_sign: SparseIdxVec, - rhs_sign: SparseIdxVec, - q_abs: SparseIdxVec, - r_abs: SparseIdxVec, - mag: SparseIdxVec, - mag_is_zero: SparseIdxVec, - diff: SparseIdxVec, - diff_bits: Vec>, - weights: [K; 7], - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(lhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(rhs_sign.len(), 1usize << ell_n); - debug_assert_eq!(q_abs.len(), 1usize << ell_n); - debug_assert_eq!(r_abs.len(), 1usize << ell_n); - debug_assert_eq!(mag.len(), 1usize << ell_n); - debug_assert_eq!(mag_is_zero.len(), 1usize << ell_n); - debug_assert_eq!(diff.len(), 1usize << ell_n); - debug_assert_eq!(diff_bits.len(), 32); - for (i, b) in diff_bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "diff_bits[{i}] length must match time domain"); - } - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs, - rhs, - rhs_is_zero, - lhs_sign, - rhs_sign, - q_abs, - r_abs, - mag, - mag_is_zero, - diff, - diff_bits, - weights, - degree_bound: 6, - } - } -} - -impl RoundOracle for Rv32PackedDivRemAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let two = K::from_u64(2); - let two32 = K::from_u64(1u64 << 32); - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - let z = self.rhs_is_zero.singleton_value(); - let lhs_sign = self.lhs_sign.singleton_value(); - let rhs_sign = self.rhs_sign.singleton_value(); - let q_abs = self.q_abs.singleton_value(); - let r_abs = self.r_abs.singleton_value(); - let mag = self.mag.singleton_value(); - let mag_z = self.mag_is_zero.singleton_value(); - let diff = self.diff.singleton_value(); - - let mut sum = K::ZERO; - for (i, b) in self.diff_bits.iter().enumerate() { - sum += b.singleton_value() * K::from_u64(1u64 << i); - } - - let lhs_abs = lhs + lhs_sign * (two32 - two * lhs); - let rhs_abs = rhs + rhs_sign * (two32 - two * rhs); - - let c0 = z * (K::ONE - z); - let c1 = z * rhs; - let c2 = mag_z * (K::ONE - mag_z); - let c3 = mag_z * mag; - let c4 = (K::ONE - z) * (lhs_abs - rhs_abs * q_abs - r_abs); - let c5 = (K::ONE - z) * (r_abs - rhs_abs - diff + two32); - let c6 = diff - sum; - - let w = &self.weights; - let expr = w[0] * c0 + w[1] * c1 + w[2] * c2 + w[3] * c3 + w[4] * c4 + w[5] * c5 + w[6] * c6; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - let z0 = self.rhs_is_zero.get(child0); - let z1 = self.rhs_is_zero.get(child1); - let lhs_sign0 = self.lhs_sign.get(child0); - let lhs_sign1 = self.lhs_sign.get(child1); - let rhs_sign0 = self.rhs_sign.get(child0); - let rhs_sign1 = self.rhs_sign.get(child1); - let q0 = self.q_abs.get(child0); - let q1 = self.q_abs.get(child1); - let r0 = self.r_abs.get(child0); - let r1 = self.r_abs.get(child1); - let mag0 = self.mag.get(child0); - let mag1 = self.mag.get(child1); - let mag_z0 = self.mag_is_zero.get(child0); - let mag_z1 = self.mag_is_zero.get(child1); - let diff0 = self.diff.get(child0); - let diff1 = self.diff.get(child1); - - let mut b0s: [K; 32] = [K::ZERO; 32]; - let mut b1s: [K; 32] = [K::ZERO; 32]; - for (j, b) in self.diff_bits.iter().enumerate() { - b0s[j] = b.get(child0); - b1s[j] = b.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs = interp(lhs0, lhs1, x); - let rhs = interp(rhs0, rhs1, x); - let z = interp(z0, z1, x); - let lhs_sign = interp(lhs_sign0, lhs_sign1, x); - let rhs_sign = interp(rhs_sign0, rhs_sign1, x); - let q_abs = interp(q0, q1, x); - let r_abs = interp(r0, r1, x); - let mag = interp(mag0, mag1, x); - let mag_z = interp(mag_z0, mag_z1, x); - let diff = interp(diff0, diff1, x); - - let mut sum = K::ZERO; - for j in 0..32 { - let b_x = interp(b0s[j], b1s[j], x); - sum += b_x * K::from_u64(1u64 << j); - } - - let lhs_abs = lhs + lhs_sign * (two32 - two * lhs); - let rhs_abs = rhs + rhs_sign * (two32 - two * rhs); - - let c0 = z * (K::ONE - z); - let c1 = z * rhs; - let c2 = mag_z * (K::ONE - mag_z); - let c3 = mag_z * mag; - let c4 = (K::ONE - z) * (lhs_abs - rhs_abs * q_abs - r_abs); - let c5 = (K::ONE - z) * (r_abs - rhs_abs - diff + two32); - let c6 = diff - sum; - - let w = &self.weights; - let expr_x = w[0] * c0 + w[1] * c1 + w[2] * c2 + w[3] * c3 + w[4] * c4 + w[5] * c5 + w[6] * c6; - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - self.rhs_is_zero.fold_round_in_place(r); - self.lhs_sign.fold_round_in_place(r); - self.rhs_sign.fold_round_in_place(r); - self.q_abs.fold_round_in_place(r); - self.r_abs.fold_round_in_place(r); - self.mag.fold_round_in_place(r); - self.mag_is_zero.fold_round_in_place(r); - self.diff.fold_round_in_place(r); - for b in self.diff_bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -// ============================================================================ -// Packed-key RV32 bitwise Shout (AND/OR/XOR) via 2-bit digits (time-domain) -// ============================================================================ - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum Rv32PackedBitwiseOp2 { - And, - Andn, - Or, - Xor, -} - -#[inline] fn rv32_two_bit_digit_bits(inv2: K, inv6: K, x: K) -> (K, K) { - // Bits for x in {0,1,2,3}, represented as degree-3 polynomials over the field: - // - bit0: 0,1,0,1 - // - bit1: 0,0,1,1 - // - // Using Lagrange basis on {0,1,2,3}: - // L1(x) = x(x-2)(x-3)/2 - // L2(x) = -x(x-1)(x-3)/2 - // L3(x) = x(x-1)(x-2)/6 - // - // Then: - // bit0(x) = L1(x) + L3(x) - // bit1(x) = L2(x) + L3(x) let xm1 = x - K::ONE; let xm2 = x - K::from_u64(2); - let xm3 = x - K::from_u64(3); - - let x_xm1 = x * xm1; - let l1 = (x * xm2 * xm3) * inv2; - let l3 = (x_xm1 * xm2) * inv6; - let l2 = -(x_xm1 * xm3) * inv2; - - let bit0 = l1 + l3; - let bit1 = l2 + l3; - (bit0, bit1) -} - -#[inline] -fn rv32_two_bit_digit_op(inv2: K, inv6: K, op: Rv32PackedBitwiseOp2, a: K, b: K) -> K { - let (a0, a1) = rv32_two_bit_digit_bits(inv2, inv6, a); - let (b0, b1) = rv32_two_bit_digit_bits(inv2, inv6, b); - - let two = K::from_u64(2); - match op { - Rv32PackedBitwiseOp2::And => { - let r0 = a0 * b0; - let r1 = a1 * b1; - r0 + two * r1 - } - Rv32PackedBitwiseOp2::Andn => { - let r0 = a0 * (K::ONE - b0); - let r1 = a1 * (K::ONE - b1); - r0 + two * r1 - } - Rv32PackedBitwiseOp2::Or => { - let r0 = a0 + b0 - a0 * b0; - let r1 = a1 + b1 - a1 * b1; - r0 + two * r1 - } - Rv32PackedBitwiseOp2::Xor => { - let r0 = a0 + b0 - two * a0 * b0; - let r1 = a1 + b1 - two * a1 * b1; - r0 + two * r1 - } - } -} - -#[inline] -fn rv32_digit4_range_poly(x: K) -> K { - // Vanishes exactly on {0,1,2,3}. - x * (x - K::ONE) * (x - K::from_u64(2)) * (x - K::from_u64(3)) -} - -pub struct Rv32PackedBitwiseAdapterOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - degree_bound: usize, - - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lhs_digits: Vec>, - rhs_digits: Vec>, - weights: Vec, -} - -impl Rv32PackedBitwiseAdapterOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs: SparseIdxVec, - rhs: SparseIdxVec, - lhs_digits: Vec>, - rhs_digits: Vec>, - weights: Vec, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(lhs.len(), 1usize << ell_n); - debug_assert_eq!(rhs.len(), 1usize << ell_n); - debug_assert_eq!(lhs_digits.len(), 16); - debug_assert_eq!(rhs_digits.len(), 16); - for (i, d) in lhs_digits.iter().enumerate() { - debug_assert_eq!( - d.len(), - 1usize << ell_n, - "lhs_digits[{i}] length must match time domain" - ); - } - for (i, d) in rhs_digits.iter().enumerate() { - debug_assert_eq!( - d.len(), - 1usize << ell_n, - "rhs_digits[{i}] length must match time domain" - ); - } - debug_assert_eq!(weights.len(), 34); - - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - digit4 range poly: degree 4 - // ⇒ total degree ≤ 1 + 1 + 4 = 6 - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - degree_bound: 6, - has_lookup, - lhs, - rhs, - lhs_digits, - rhs_digits, - weights, - } - } -} - -impl RoundOracle for Rv32PackedBitwiseAdapterOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - let w_lhs = self.weights[0]; - let w_rhs = self.weights[1]; - let w_digits = &self.weights[2..]; - - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let lhs = self.lhs.singleton_value(); - let rhs = self.rhs.singleton_value(); - - let mut lhs_recon = K::ZERO; - let mut rhs_recon = K::ZERO; - for i in 0..16usize { - lhs_recon += self.lhs_digits[i].singleton_value() * K::from_u64(1u64 << (2 * i)); - rhs_recon += self.rhs_digits[i].singleton_value() * K::from_u64(1u64 << (2 * i)); - } - - let mut range_sum = K::ZERO; - for (i, d) in self.lhs_digits.iter().enumerate() { - range_sum += w_digits[i] * rv32_digit4_range_poly(d.singleton_value()); - } - for (i, d) in self.rhs_digits.iter().enumerate() { - range_sum += w_digits[16 + i] * rv32_digit4_range_poly(d.singleton_value()); - } - - let expr = w_lhs * (lhs - lhs_recon) + w_rhs * (rhs - rhs_recon) + range_sum; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let lhs0 = self.lhs.get(child0); - let lhs1 = self.lhs.get(child1); - let rhs0 = self.rhs.get(child0); - let rhs1 = self.rhs.get(child1); - - let mut a0s: [K; 16] = [K::ZERO; 16]; - let mut a1s: [K; 16] = [K::ZERO; 16]; - for (i, d) in self.lhs_digits.iter().enumerate() { - a0s[i] = d.get(child0); - a1s[i] = d.get(child1); - } - let mut b0s: [K; 16] = [K::ZERO; 16]; - let mut b1s: [K; 16] = [K::ZERO; 16]; - for (i, d) in self.rhs_digits.iter().enumerate() { - b0s[i] = d.get(child0); - b1s[i] = d.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let lhs_x = interp(lhs0, lhs1, x); - let rhs_x = interp(rhs0, rhs1, x); - - let mut lhs_recon = K::ZERO; - let mut rhs_recon = K::ZERO; - let mut range_sum = K::ZERO; - for j in 0..16usize { - let aj = interp(a0s[j], a1s[j], x); - let bj = interp(b0s[j], b1s[j], x); - let pow = K::from_u64(1u64 << (2 * j)); - lhs_recon += aj * pow; - rhs_recon += bj * pow; - - range_sum += w_digits[j] * rv32_digit4_range_poly(aj); - range_sum += w_digits[16 + j] * rv32_digit4_range_poly(bj); - } - - let expr = w_lhs * (lhs_x - lhs_recon) + w_rhs * (rhs_x - rhs_recon) + range_sum; - ys[i] += chi_x * gate_x * expr; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.lhs.fold_round_in_place(r); - self.rhs.fold_round_in_place(r); - for d in self.lhs_digits.iter_mut() { - d.fold_round_in_place(r); - } - for d in self.rhs_digits.iter_mut() { - d.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -pub struct Rv32PackedBitwiseOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - lhs_digits: Vec>, - rhs_digits: Vec>, - val: SparseIdxVec, - op: Rv32PackedBitwiseOp2, - inv2: K, - inv6: K, - degree_bound: usize, -} - -impl Rv32PackedBitwiseOracleSparseTime { - fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs_digits: Vec>, - rhs_digits: Vec>, - val: SparseIdxVec, - op: Rv32PackedBitwiseOp2, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - debug_assert_eq!(lhs_digits.len(), 16); - debug_assert_eq!(rhs_digits.len(), 16); - for (i, d) in lhs_digits.iter().enumerate() { - debug_assert_eq!( - d.len(), - 1usize << ell_n, - "lhs_digits[{i}] length must match time domain" - ); - } - for (i, d) in rhs_digits.iter().enumerate() { - debug_assert_eq!( - d.len(), - 1usize << ell_n, - "rhs_digits[{i}] length must match time domain" - ); - } - - // Degree bound: - // - chi(t): multilinear (degree 1) - // - has_lookup(t): multilinear (degree 1) - // - bitwise digit op: degree 6 (two degree-3 bit extractors multiplied) - // ⇒ total degree ≤ 1 + 1 + 6 = 8 - let inv2 = K::from_u64(2).inverse(); - let inv6 = K::from_u64(6).inverse(); - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - lhs_digits, - rhs_digits, - val, - op, - inv2, - inv6, - degree_bound: 8, - } - } -} - -impl RoundOracle for Rv32PackedBitwiseOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let val = self.val.singleton_value(); - - let mut out = K::ZERO; - for i in 0..16usize { - let a = self.lhs_digits[i].singleton_value(); - let b = self.rhs_digits[i].singleton_value(); - let digit = rv32_two_bit_digit_op(self.inv2, self.inv6, self.op, a, b); - out += digit * K::from_u64(1u64 << (2 * i)); - } - let expr = out - val; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - - let mut a0s: [K; 16] = [K::ZERO; 16]; - let mut a1s: [K; 16] = [K::ZERO; 16]; - for (j, d) in self.lhs_digits.iter().enumerate() { - a0s[j] = d.get(child0); - a1s[j] = d.get(child1); - } - let mut b0s: [K; 16] = [K::ZERO; 16]; - let mut b1s: [K; 16] = [K::ZERO; 16]; - for (j, d) in self.rhs_digits.iter().enumerate() { - b0s[j] = d.get(child0); - b1s[j] = d.get(child1); - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - - let val_x = interp(val0, val1, x); - - let mut out = K::ZERO; - for j in 0..16usize { - let a = interp(a0s[j], a1s[j], x); - let b = interp(b0s[j], b1s[j], x); - let digit = rv32_two_bit_digit_op(self.inv2, self.inv6, self.op, a, b); - out += digit * K::from_u64(1u64 << (2 * j)); - } - - let expr = out - val_x; - ys[i] += chi_x * gate_x * expr; - } - } - ys - } + let xm3 = x - K::from_u64(3); - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } + let x_xm1 = x * xm1; + let l1 = (x * xm2 * xm3) * inv2; + let l3 = (x_xm1 * xm2) * inv6; + let l2 = -(x_xm1 * xm3) * inv2; - fn degree_bound(&self) -> usize { - self.degree_bound - } + let bit0 = l1 + l3; + let bit1 = l2 + l3; + (bit0, bit1) +} - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; +fn rv32_two_bit_digit_op(inv2: K, inv6: K, op: Rv32PackedBitwiseOp2, a: K, b: K) -> K { + let (a0, a1) = rv32_two_bit_digit_bits(inv2, inv6, a); + let (b0, b1) = rv32_two_bit_digit_bits(inv2, inv6, b); + + let two = K::from_u64(2); + match op { + Rv32PackedBitwiseOp2::And => { + let r0 = a0 * b0; + let r1 = a1 * b1; + r0 + two * r1 } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - for d in self.lhs_digits.iter_mut() { - d.fold_round_in_place(r); + Rv32PackedBitwiseOp2::Andn => { + let r0 = a0 * (K::ONE - b0); + let r1 = a1 * (K::ONE - b1); + r0 + two * r1 } - for d in self.rhs_digits.iter_mut() { - d.fold_round_in_place(r); + Rv32PackedBitwiseOp2::Or => { + let r0 = a0 + b0 - a0 * b0; + let r1 = a1 + b1 - a1 * b1; + r0 + two * r1 } - self.val.fold_round_in_place(r); - self.bit_idx += 1; - } -} - -pub struct Rv32PackedAndOracleSparseTime { - core: Rv32PackedBitwiseOracleSparseTime, -} -impl Rv32PackedAndOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs_digits: Vec>, - rhs_digits: Vec>, - val: SparseIdxVec, - ) -> Self { - Self { - core: Rv32PackedBitwiseOracleSparseTime::new( - r_cycle, - has_lookup, - lhs_digits, - rhs_digits, - val, - Rv32PackedBitwiseOp2::And, - ), + Rv32PackedBitwiseOp2::Xor => { + let r0 = a0 + b0 - two * a0 * b0; + let r1 = a1 + b1 - two * a1 * b1; + r0 + two * r1 } } } -impl_round_oracle_via_core!(Rv32PackedAndOracleSparseTime); -pub struct Rv32PackedAndnOracleSparseTime { - core: Rv32PackedBitwiseOracleSparseTime, -} -impl Rv32PackedAndnOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs_digits: Vec>, - rhs_digits: Vec>, - val: SparseIdxVec, - ) -> Self { - Self { - core: Rv32PackedBitwiseOracleSparseTime::new( - r_cycle, - has_lookup, - lhs_digits, - rhs_digits, - val, - Rv32PackedBitwiseOp2::Andn, - ), - } - } +fn rv32_digit4_range_poly(x: K) -> K { + x * (x - K::ONE) * (x - K::from_u64(2)) * (x - K::from_u64(3)) } -impl_round_oracle_via_core!(Rv32PackedAndnOracleSparseTime); -pub struct Rv32PackedOrOracleSparseTime { - core: Rv32PackedBitwiseOracleSparseTime, -} -impl Rv32PackedOrOracleSparseTime { - pub fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - lhs_digits: Vec>, - rhs_digits: Vec>, - val: SparseIdxVec, - ) -> Self { - Self { - core: Rv32PackedBitwiseOracleSparseTime::new( - r_cycle, - has_lookup, - lhs_digits, - rhs_digits, - val, - Rv32PackedBitwiseOp2::Or, - ), - } +fn expr_rv32_packed_bitwise_adapter(cols: &[K; 2], bits: &[K], w: &[K; 34]) -> K { + let lhs = cols[0]; + let rhs = cols[1]; + debug_assert_eq!(bits.len(), 32); + + let mut lhs_recon = K::ZERO; + let mut rhs_recon = K::ZERO; + let mut range_sum = K::ZERO; + for i in 0..16usize { + let pow = K::from_u64(1u64 << (2 * i)); + let a = bits[i]; + let b = bits[16 + i]; + lhs_recon += a * pow; + rhs_recon += b * pow; + range_sum += w[2 + i] * rv32_digit4_range_poly(a); + range_sum += w[2 + 16 + i] * rv32_digit4_range_poly(b); } + + w[0] * (lhs - lhs_recon) + w[1] * (rhs - rhs_recon) + range_sum } -impl_round_oracle_via_core!(Rv32PackedOrOracleSparseTime); -pub struct Rv32PackedXorOracleSparseTime { - core: Rv32PackedBitwiseOracleSparseTime, +pub struct Rv32PackedBitwiseAdapterOracleSparseTime { + core: SparseColsBitsExprOracle<2, 34>, } -impl Rv32PackedXorOracleSparseTime { + +impl Rv32PackedBitwiseAdapterOracleSparseTime { pub fn new( r_cycle: &[K], has_lookup: SparseIdxVec, + lhs: SparseIdxVec, + rhs: SparseIdxVec, lhs_digits: Vec>, rhs_digits: Vec>, - val: SparseIdxVec, + weights: Vec, ) -> Self { + debug_assert_eq!(lhs_digits.len(), 16); + debug_assert_eq!(rhs_digits.len(), 16); + debug_assert_eq!(weights.len(), 34); + let expr_weights: [K; 34] = weights + .try_into() + .unwrap_or_else(|v: Vec| panic!("bitwise adapter weights length must be 34, got {}", v.len())); + let mut digits = lhs_digits; + digits.extend(rhs_digits); Self { - core: Rv32PackedBitwiseOracleSparseTime::new( + core: SparseColsBitsExprOracle::new( r_cycle, has_lookup, - lhs_digits, - rhs_digits, - val, - Rv32PackedBitwiseOp2::Xor, + [lhs, rhs], + digits, + expr_weights, + 6, + expr_rv32_packed_bitwise_adapter, ), } } } -impl_round_oracle_via_core!(Rv32PackedXorOracleSparseTime); - -/// Sparse Route A oracle for u32 bit-decomposition consistency: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (x(t) - Σ_i 2^i · bit_i(t)) -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct U32DecompOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - x: SparseIdxVec, - bits: Vec>, - weights: Vec, - degree_bound: usize, -} - -impl U32DecompOracleSparseTime { - pub fn new(r_cycle: &[K], has_lookup: SparseIdxVec, x: SparseIdxVec, bits: Vec>) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(x.len(), 1usize << ell_n); - debug_assert_eq!(bits.len(), 32); - for (i, b) in bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "bits[{i}] length must match time domain"); - } - let weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); +impl_round_oracle_via_core!(Rv32PackedBitwiseAdapterOracleSparseTime); - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - x, - bits, - weights, - degree_bound: 3, - } +fn expr_rv32_packed_bitwise(cols: &[K; 1], bits: &[K], w: &[K; 2], op: Rv32PackedBitwiseOp2) -> K { + let val = cols[0]; + debug_assert_eq!(bits.len(), 32); + let mut out = K::ZERO; + for i in 0..16usize { + let a = bits[i]; + let b = bits[16 + i]; + let digit = rv32_two_bit_digit_op(w[0], w[1], op, a, b); + out += digit * K::from_u64(1u64 << (2 * i)); } + out - val } -impl RoundOracle for U32DecompOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let x = self.x.singleton_value(); - let mut sum = K::ZERO; - for (b, w) in self.bits.iter().zip(self.weights.iter()) { - sum += b.singleton_value() * *w; - } - let expr = x - sum; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let x0 = self.x.get(child0); - let x1 = self.x.get(child1); - - let mut expr0 = x0; - let mut expr1 = x1; - for (b_col, w) in self.bits.iter().zip(self.weights.iter()) { - let b0 = b_col.get(child0); - let b1 = b_col.get(child1); - expr0 -= b0 * *w; - expr1 -= b1 * *w; - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let expr_x = interp(expr0, expr1, x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.x.fold_round_in_place(r); - for b in self.bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } +fn expr_rv32_packed_and(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::And) } -/// Sparse Route A oracle for u5 bit-decomposition consistency: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · (x(t) - Σ_i 2^i · bit_i(t)) -/// -/// Intended usage: set the claimed sum to 0 to enforce correctness. -pub struct U5DecompOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - x: SparseIdxVec, - bits: Vec>, - weights: Vec, - degree_bound: usize, +fn expr_rv32_packed_andn(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::Andn) } -impl U5DecompOracleSparseTime { - pub fn new(r_cycle: &[K], has_lookup: SparseIdxVec, x: SparseIdxVec, bits: Vec>) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(x.len(), 1usize << ell_n); - debug_assert_eq!(bits.len(), 5); - for (i, b) in bits.iter().enumerate() { - debug_assert_eq!(b.len(), 1usize << ell_n, "bits[{i}] length must match time domain"); - } - let weights: Vec = (0..5).map(|i| K::from_u64(1u64 << i)).collect(); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - x, - bits, - weights, - degree_bound: 3, - } - } +fn expr_rv32_packed_or(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::Or) } -impl RoundOracle for U5DecompOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let x = self.x.singleton_value(); - let mut sum = K::ZERO; - for (b, w) in self.bits.iter().zip(self.weights.iter()) { - sum += b.singleton_value() * *w; - } - let expr = x - sum; - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); - let half = self.has_lookup.len() / 2; - debug_assert!(pairs.iter().all(|&p| p < half)); - - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let x0 = self.x.get(child0); - let x1 = self.x.get(child1); - - let mut expr0 = x0; - let mut expr1 = x1; - for (b_col, w) in self.bits.iter().zip(self.weights.iter()) { - let b0 = b_col.get(child0); - let b1 = b_col.get(child1); - expr0 -= b0 * *w; - expr1 -= b1 * *w; - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); +fn expr_rv32_packed_xor(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::Xor) +} - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let expr_x = interp(expr0, expr1, x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } +macro_rules! define_rv32_packed_bitwise_oracle { + ($name:ident, $expr:expr) => { + pub struct $name { + core: SparseColsBitsExprOracle<1, 2>, } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + lhs_digits: Vec>, + rhs_digits: Vec>, + val: SparseIdxVec, + ) -> Self { + debug_assert_eq!(lhs_digits.len(), 16); + debug_assert_eq!(rhs_digits.len(), 16); + let mut digits = lhs_digits; + digits.extend(rhs_digits); + let inv2 = K::from_u64(2).inverse(); + let inv6 = K::from_u64(6).inverse(); + Self { + core: SparseColsBitsExprOracle::new(r_cycle, has_lookup, [val], digits, [inv2, inv6], 8, $expr), + } + } + } + impl_round_oracle_via_core!($name); + }; +} - fn degree_bound(&self) -> usize { - self.degree_bound - } +define_rv32_packed_bitwise_oracle!(Rv32PackedAndOracleSparseTime, expr_rv32_packed_and); +define_rv32_packed_bitwise_oracle!(Rv32PackedAndnOracleSparseTime, expr_rv32_packed_andn); +define_rv32_packed_bitwise_oracle!(Rv32PackedOrOracleSparseTime, expr_rv32_packed_or); +define_rv32_packed_bitwise_oracle!(Rv32PackedXorOracleSparseTime, expr_rv32_packed_xor); - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; +macro_rules! define_u_decomp_oracle { + ($name:ident, $num_bits:expr) => { + pub struct $name { + core: SparseWeightedBitsExprOracle<1>, } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.x.fold_round_in_place(r); - for b in self.bits.iter_mut() { - b.fold_round_in_place(r); + + impl $name { + pub fn new( + r_cycle: &[K], + has_lookup: SparseIdxVec, + x: SparseIdxVec, + bits: Vec>, + ) -> Self { + debug_assert_eq!(bits.len(), $num_bits); + let weights: Vec = (0..$num_bits).map(|i| K::from_u64(1u64 << i)).collect(); + Self { + core: SparseWeightedBitsExprOracle::new(r_cycle, has_lookup, [x], bits, weights, 3, expr_u_decomp), + } + } } - self.bit_idx += 1; - } + impl_round_oracle_via_core!($name); + }; } -/// Zero oracle over the time hypercube (for placeholder claims). +define_u_decomp_oracle!(U32DecompOracleSparseTime, 32); + +define_u_decomp_oracle!(U5DecompOracleSparseTime, 5); + pub struct ZeroOracleSparseTime { rounds_remaining: usize, degree_bound: usize, @@ -5236,7 +2006,6 @@ impl RoundOracle for ZeroOracleSparseTime { } } -#[inline] fn interp(f0: K, f1: K, x: K) -> K { f0 + (f1 - f0) * x } @@ -5292,12 +2061,7 @@ fn build_inc_terms_at_r_addr( continue; } - let mut eq_addr = K::ONE; - for (b, col) in wa_bits.iter().enumerate() { - let bit = col.get(t); - eq_addr *= eq_bit_affine(bit, r_addr[b]); - } - + let eq_addr = eq_addr_at_time(wa_bits, r_addr, t); let inc_at_r_addr = has_w * inc_t * eq_addr; if inc_at_r_addr != K::ZERO { out.push((t, inc_at_r_addr)); @@ -5306,13 +2070,53 @@ fn build_inc_terms_at_r_addr( out } -// ============================================================================ -// Sparse-in-time Route A oracles (Track A, shared CPU bus) -// ============================================================================ +fn eq_addr_at_time(bit_cols: &[SparseIdxVec], r_addr: &[K], t: usize) -> K { + debug_assert_eq!(bit_cols.len(), r_addr.len()); + let mut eq_addr = K::ONE; + for (b, col) in bit_cols.iter().enumerate() { + eq_addr *= eq_bit_affine(col.get(t), r_addr[b]); + } + eq_addr +} + +fn eq_addr_singleton(bit_cols: &[SparseIdxVec], r_addr: &[K]) -> K { + debug_assert_eq!(bit_cols.len(), r_addr.len()); + let mut eq_addr = K::ONE; + for (b, col) in bit_cols.iter().enumerate() { + eq_addr *= eq_bit_affine(col.singleton_value(), r_addr[b]); + } + eq_addr +} + +fn accumulate_pair_with_eq_addr_over_points( + ys: &mut [K], + points: &[K], + bit_cols: &[SparseIdxVec], + r_addr: &[K], + child0: usize, + child1: usize, + mut coeff_at: F, +) +where + F: FnMut(K) -> K, +{ + debug_assert_eq!(bit_cols.len(), r_addr.len()); + let mut eq0s = Vec::with_capacity(bit_cols.len()); + let mut d_eqs = Vec::with_capacity(bit_cols.len()); + for (b, col) in bit_cols.iter().enumerate() { + let e0 = eq_bit_affine(col.get(child0), r_addr[b]); + eq0s.push(e0); + d_eqs.push(eq_bit_affine(col.get(child1), r_addr[b]) - e0); + } + for (i, &x) in points.iter().enumerate() { + let mut eq_addr = K::ONE; + for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { + eq_addr *= *e0 + *de * x; + } + ys[i] += coeff_at(x) * eq_addr; + } +} -/// Per-lane Twist bus columns (sparse in time). -/// -/// A "lane" is an independent per-step access slot with its own read + write ports. #[derive(Clone, Debug)] pub struct TwistLaneSparseCols { pub ra_bits: Vec>, @@ -5324,8 +2128,6 @@ pub struct TwistLaneSparseCols { pub inc_at_write_addr: SparseIdxVec, } -/// Sparse Route A oracle for Shout adapter: -/// Σ_t χ_{r_cycle}(t) · has_lookup(t) · Π_b eq(addr_bits_b(t), r_addr_b) pub struct IndexAdapterOracleSparseTime { bit_idx: usize, r_cycle: Vec, @@ -5333,8 +2135,6 @@ pub struct IndexAdapterOracleSparseTime { has_lookup: SparseIdxVec, addr_bits: Vec>, r_addr: Vec, - degree_bound: usize, - challenges: Vec, } impl IndexAdapterOracleSparseTime { @@ -5348,20 +2148,11 @@ impl IndexAdapterOracleSparseTime { let ell_addr = addr_bits.len(); debug_assert_eq!(has_lookup.len(), 1usize << ell_n); debug_assert_eq!(r_addr.len(), ell_addr); - for (b, col) in addr_bits.iter().enumerate() { - debug_assert_eq!( - col.len(), - 1usize << ell_n, - "addr_bits[{b}] length must match time domain" - ); - } + assert_cols_match_time(&addr_bits, 1usize << ell_n); let mut claim = K::ZERO; for &(t, gate) in has_lookup.entries() { - let mut eq_addr = K::ONE; - for (b, col) in addr_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.get(t), r_addr[b]); - } + let eq_addr = eq_addr_at_time(&addr_bits, r_addr, t); claim += chi_at_bool_index(r_cycle, t) * gate * eq_addr; } @@ -5373,8 +2164,6 @@ impl IndexAdapterOracleSparseTime { has_lookup, addr_bits, r_addr: r_addr.to_vec(), - degree_bound: 2 + ell_addr, - challenges: Vec::with_capacity(ell_n), }, claim, ) @@ -5384,17 +2173,13 @@ impl IndexAdapterOracleSparseTime { impl RoundOracle for IndexAdapterOracleSparseTime { fn evals_at(&mut self, points: &[K]) -> Vec { if self.has_lookup.len() == 1 { - let mut eq_addr = K::ONE; - for (b, col) in self.addr_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.singleton_value(), self.r_addr[b]); - } + let eq_addr = eq_addr_singleton(&self.addr_bits, &self.r_addr); let v = self.prefix_eq * self.has_lookup.singleton_value() * eq_addr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_lookup.entries()); let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -5405,27 +2190,10 @@ impl RoundOracle for IndexAdapterOracleSparseTime { } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - // Per-bit eq factors at the children (after any time folding). - let mut eq0s: Vec = Vec::with_capacity(self.addr_bits.len()); - let mut d_eqs: Vec = Vec::with_capacity(self.addr_bits.len()); - for (b, col) in self.addr_bits.iter().enumerate() { - let e0 = eq_bit_affine(col.get(child0), self.r_addr[b]); - let e1 = eq_bit_affine(col.get(child1), self.r_addr[b]); - eq0s.push(e0); - d_eqs.push(e1 - e0); - } - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - let gate_x = interp(gate0, gate1, x); - let mut prod = chi_x * gate_x; - for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { - prod *= *e0 + *de * x; - } - ys[i] += prod; - } - } + accumulate_pair_with_eq_addr_over_points(&mut ys, points, &self.addr_bits, &self.r_addr, child0, child1, |x| { + interp(chi0, chi1, x) * interp(gate0, gate1, x) + }); + }); ys } @@ -5434,7 +2202,7 @@ impl RoundOracle for IndexAdapterOracleSparseTime { } fn degree_bound(&self) -> usize { - self.degree_bound + 2 + self.r_addr.len() } fn fold(&mut self, r: K) { @@ -5446,17 +2214,10 @@ impl RoundOracle for IndexAdapterOracleSparseTime { for col in self.addr_bits.iter_mut() { col.fold_round_in_place(r); } - self.challenges.push(r); self.bit_idx += 1; } } -/// Sparse Route A oracle for χ_{r_cycle}-weighted *aggregated* bitness: -/// Σ_t χ_{r_cycle}(t) · ( Σ_i w_i · col_i(t) · (col_i(t) - 1) ) -/// -/// This reduces O(#bit-columns) separate degree-3 sumchecks to a single degree-3 sumcheck. -/// The weights `w_i` MUST be derived deterministically from transcript-known data (e.g. r_cycle), -/// so prover/verifier agree on the same polynomial. pub struct LazyWeightedBitnessOracleSparseTime { bit_idx: usize, r_cycle: Vec, @@ -5464,7 +2225,6 @@ pub struct LazyWeightedBitnessOracleSparseTime { cols: Vec>, weights: Vec, degree_bound: usize, - challenges: Vec, } impl LazyWeightedBitnessOracleSparseTime { @@ -5481,7 +2241,6 @@ impl LazyWeightedBitnessOracleSparseTime { cols, weights, degree_bound: 3, - challenges: Vec::with_capacity(ell_n), } } } @@ -5502,7 +2261,6 @@ impl RoundOracle for LazyWeightedBitnessOracleSparseTime { return vec![v; points.len()]; } - // Union of parent indices whose children contain any nonzero across the aggregated columns. let mut pairs: Vec = Vec::new(); for col in &self.cols { for &(idx, _v) in col.entries() { @@ -5555,276 +2313,70 @@ impl RoundOracle for LazyWeightedBitnessOracleSparseTime { return; } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - for col in self.cols.iter_mut() { - col.fold_round_in_place(r); - } - self.challenges.push(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for Twist read-check (time rounds only). -pub struct TwistReadCheckOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - degree_bound: usize, - - r_addr: Vec, - ra_bits: Vec>, - has_read: SparseIdxVec, - rv: SparseIdxVec, - - init_at_r_addr: K, - inc_terms_at_r_addr: Vec<(usize, K)>, - - t_child0: Vec, - t_child1: Vec, - challenges: Vec, -} - -impl TwistReadCheckOracleSparseTime { - #[allow(clippy::too_many_arguments)] - pub fn new( - r_cycle: &[K], - has_read: SparseIdxVec, - rv: SparseIdxVec, - ra_bits: Vec>, - // Write stream (for Val_pre). - has_write: SparseIdxVec, - inc_at_write_addr: SparseIdxVec, - wa_bits: Vec>, - r_addr: &[K], - init_at_r_addr: K, - ) -> Self { - let ell_n = r_cycle.len(); - let ell_addr = r_addr.len(); - debug_assert_eq!(has_read.len(), 1usize << ell_n); - debug_assert_eq!(rv.len(), 1usize << ell_n); - debug_assert_eq!(has_write.len(), 1usize << ell_n); - debug_assert_eq!(inc_at_write_addr.len(), 1usize << ell_n); - debug_assert_eq!(ra_bits.len(), ell_addr); - debug_assert_eq!(wa_bits.len(), ell_addr); - - let inc_terms_at_r_addr = build_inc_terms_at_r_addr(&wa_bits, &has_write, &inc_at_write_addr, r_addr); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - degree_bound: 3 + ell_addr, - r_addr: r_addr.to_vec(), - ra_bits, - has_read, - rv, - init_at_r_addr, - inc_terms_at_r_addr, - t_child0: vec![K::ZERO; ell_n], - t_child1: vec![K::ZERO; ell_n], - challenges: Vec::with_capacity(ell_n), - } - } - - pub fn new_with_inc_terms( - r_cycle: &[K], - has_read: SparseIdxVec, - rv: SparseIdxVec, - ra_bits: Vec>, - r_addr: &[K], - init_at_r_addr: K, - inc_terms_at_r_addr: Vec<(usize, K)>, - ) -> Self { - let ell_n = r_cycle.len(); - let ell_addr = r_addr.len(); - debug_assert_eq!(has_read.len(), 1usize << ell_n); - debug_assert_eq!(rv.len(), 1usize << ell_n); - debug_assert_eq!(ra_bits.len(), ell_addr); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - degree_bound: 3 + ell_addr, - r_addr: r_addr.to_vec(), - ra_bits, - has_read, - rv, - init_at_r_addr, - inc_terms_at_r_addr, - t_child0: vec![K::ZERO; ell_n], - t_child1: vec![K::ZERO; ell_n], - challenges: Vec::with_capacity(ell_n), - } - } -} - -impl RoundOracle for TwistReadCheckOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_read.len() == 1 { - let mut eq_addr = K::ONE; - for (b, col) in self.ra_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.singleton_value(), self.r_addr[b]); - } - let t_point = self.challenges.as_slice(); - let val_pre = val_pre_from_inc_terms(self.init_at_r_addr, &self.inc_terms_at_r_addr, t_point); - let diff = val_pre - self.rv.singleton_value(); - let v = self.prefix_eq * self.has_read.singleton_value() * diff * eq_addr; - return vec![v; points.len()]; - } - - let pairs = gather_pairs_from_sparse(self.has_read.entries()); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_read.get(child0); - let gate1 = self.has_read.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - fill_time_point(&mut self.t_child0, &self.challenges, self.bit_idx, K::ZERO, pair); - fill_time_point(&mut self.t_child1, &self.challenges, self.bit_idx, K::ONE, pair); - let val_pre0 = val_pre_from_inc_terms(self.init_at_r_addr, &self.inc_terms_at_r_addr, &self.t_child0); - let val_pre1 = val_pre_from_inc_terms(self.init_at_r_addr, &self.inc_terms_at_r_addr, &self.t_child1); - - let diff0 = val_pre0 - self.rv.get(child0); - let diff1 = val_pre1 - self.rv.get(child1); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - let mut eq0s: Vec = Vec::with_capacity(self.ra_bits.len()); - let mut d_eqs: Vec = Vec::with_capacity(self.ra_bits.len()); - for (b, col) in self.ra_bits.iter().enumerate() { - let e0 = eq_bit_affine(col.get(child0), self.r_addr[b]); - let e1 = eq_bit_affine(col.get(child1), self.r_addr[b]); - eq0s.push(e0); - d_eqs.push(e1 - e0); - } - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - let gate_x = interp(gate0, gate1, x); - let diff_x = interp(diff0, diff1, x); - let mut prod = chi_x * gate_x * diff_x; - for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { - prod *= *e0 + *de * x; - } - ys[i] += prod; - } - } - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_read.fold_round_in_place(r); - self.rv.fold_round_in_place(r); - for col in self.ra_bits.iter_mut() { - col.fold_round_in_place(r); - } - self.challenges.push(r); - self.bit_idx += 1; - } -} - -/// Sparse Route A oracle for Twist write-check (time rounds only). -pub struct TwistWriteCheckOracleSparseTime { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - degree_bound: usize, - - r_addr: Vec, - wa_bits: Vec>, - has_write: SparseIdxVec, - wv: SparseIdxVec, - inc_at_write_addr: SparseIdxVec, - - init_at_r_addr: K, - inc_terms_at_r_addr: Vec<(usize, K)>, - - t_child0: Vec, - t_child1: Vec, - challenges: Vec, -} - -impl TwistWriteCheckOracleSparseTime { - #[allow(clippy::too_many_arguments)] - pub fn new( - r_cycle: &[K], - has_write: SparseIdxVec, - wv: SparseIdxVec, - inc_at_write_addr: SparseIdxVec, - wa_bits: Vec>, - r_addr: &[K], - init_at_r_addr: K, - ) -> Self { - let ell_n = r_cycle.len(); - let ell_addr = r_addr.len(); - debug_assert_eq!(has_write.len(), 1usize << ell_n); - debug_assert_eq!(wv.len(), 1usize << ell_n); - debug_assert_eq!(inc_at_write_addr.len(), 1usize << ell_n); - debug_assert_eq!(wa_bits.len(), ell_addr); - - let inc_terms_at_r_addr = build_inc_terms_at_r_addr(&wa_bits, &has_write, &inc_at_write_addr, r_addr); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - degree_bound: 3 + ell_addr, - r_addr: r_addr.to_vec(), - wa_bits, - has_write, - wv, - inc_at_write_addr, - init_at_r_addr, - inc_terms_at_r_addr, - t_child0: vec![K::ZERO; ell_n], - t_child1: vec![K::ZERO; ell_n], - challenges: Vec::with_capacity(ell_n), + for col in self.cols.iter_mut() { + col.fold_round_in_place(r); } + self.bit_idx += 1; } +} - pub fn new_with_inc_terms( +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum TwistTimeCheckKind { + Read, + Write, +} + +struct TwistTimeCheckOracleSparseTimeCore { + kind: TwistTimeCheckKind, + bit_idx: usize, + r_cycle: Vec, + prefix_eq: K, + + r_addr: Vec, + addr_bits: Vec>, + gate: SparseIdxVec, + value: SparseIdxVec, + inc_at_write_addr: Option>, + + init_at_r_addr: K, + inc_terms_at_r_addr: Vec<(usize, K)>, + + t_child0: Vec, + t_child1: Vec, + challenges: Vec, +} + +impl TwistTimeCheckOracleSparseTimeCore { + #[allow(clippy::too_many_arguments)] + fn new( + kind: TwistTimeCheckKind, r_cycle: &[K], - has_write: SparseIdxVec, - wv: SparseIdxVec, - inc_at_write_addr: SparseIdxVec, - wa_bits: Vec>, + gate: SparseIdxVec, + value: SparseIdxVec, + addr_bits: Vec>, + inc_at_write_addr: Option>, r_addr: &[K], init_at_r_addr: K, inc_terms_at_r_addr: Vec<(usize, K)>, ) -> Self { let ell_n = r_cycle.len(); let ell_addr = r_addr.len(); - debug_assert_eq!(has_write.len(), 1usize << ell_n); - debug_assert_eq!(wv.len(), 1usize << ell_n); - debug_assert_eq!(inc_at_write_addr.len(), 1usize << ell_n); - debug_assert_eq!(wa_bits.len(), ell_addr); + debug_assert_eq!(gate.len(), 1usize << ell_n); + debug_assert_eq!(value.len(), 1usize << ell_n); + if let Some(inc) = inc_at_write_addr.as_ref() { + debug_assert_eq!(inc.len(), 1usize << ell_n); + } + debug_assert_eq!(addr_bits.len(), ell_addr); Self { + kind, bit_idx: 0, r_cycle: r_cycle.to_vec(), prefix_eq: K::ONE, - degree_bound: 3 + ell_addr, r_addr: r_addr.to_vec(), - wa_bits, - has_write, - wv, + addr_bits, + gate, + value, inc_at_write_addr, init_at_r_addr, inc_terms_at_r_addr, @@ -5835,28 +2387,34 @@ impl TwistWriteCheckOracleSparseTime { } } -impl RoundOracle for TwistWriteCheckOracleSparseTime { +impl RoundOracle for TwistTimeCheckOracleSparseTimeCore { fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_write.len() == 1 { - let mut eq_addr = K::ONE; - for (b, col) in self.wa_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.singleton_value(), self.r_addr[b]); - } + let is_write = matches!(self.kind, TwistTimeCheckKind::Write); + if self.gate.len() == 1 { + let eq_addr = eq_addr_singleton(&self.addr_bits, &self.r_addr); let t_point = self.challenges.as_slice(); let val_pre = val_pre_from_inc_terms(self.init_at_r_addr, &self.inc_terms_at_r_addr, t_point); - let delta = self.wv.singleton_value() - val_pre - self.inc_at_write_addr.singleton_value(); - let v = self.prefix_eq * self.has_write.singleton_value() * delta * eq_addr; + let inc = self + .inc_at_write_addr + .as_ref() + .map(|c| c.singleton_value()) + .unwrap_or(K::ZERO); + let term = if is_write { + self.value.singleton_value() - val_pre - inc + } else { + val_pre - self.value.singleton_value() + }; + let v = self.prefix_eq * self.gate.singleton_value() * term * eq_addr; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_write.entries()); let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.gate.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; - let gate0 = self.has_write.get(child0); - let gate1 = self.has_write.get(child1); + let gate0 = self.gate.get(child0); + let gate1 = self.gate.get(child1); if gate0 == K::ZERO && gate1 == K::ZERO { continue; } @@ -5866,31 +2424,34 @@ impl RoundOracle for TwistWriteCheckOracleSparseTime { let val_pre0 = val_pre_from_inc_terms(self.init_at_r_addr, &self.inc_terms_at_r_addr, &self.t_child0); let val_pre1 = val_pre_from_inc_terms(self.init_at_r_addr, &self.inc_terms_at_r_addr, &self.t_child1); - let delta0 = self.wv.get(child0) - val_pre0 - self.inc_at_write_addr.get(child0); - let delta1 = self.wv.get(child1) - val_pre1 - self.inc_at_write_addr.get(child1); + let value0 = self.value.get(child0); + let value1 = self.value.get(child1); + let inc0 = self + .inc_at_write_addr + .as_ref() + .map(|c| c.get(child0)) + .unwrap_or(K::ZERO); + let inc1 = self + .inc_at_write_addr + .as_ref() + .map(|c| c.get(child1)) + .unwrap_or(K::ZERO); + let term0 = if is_write { + value0 - val_pre0 - inc0 + } else { + val_pre0 - value0 + }; + let term1 = if is_write { + value1 - val_pre1 - inc1 + } else { + val_pre1 - value1 + }; let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - let mut eq0s: Vec = Vec::with_capacity(self.wa_bits.len()); - let mut d_eqs: Vec = Vec::with_capacity(self.wa_bits.len()); - for (b, col) in self.wa_bits.iter().enumerate() { - let e0 = eq_bit_affine(col.get(child0), self.r_addr[b]); - let e1 = eq_bit_affine(col.get(child1), self.r_addr[b]); - eq0s.push(e0); - d_eqs.push(e1 - e0); - } - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - let gate_x = interp(gate0, gate1, x); - let delta_x = interp(delta0, delta1, x); - let mut prod = chi_x * gate_x * delta_x; - for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { - prod *= *e0 + *de * x; - } - ys[i] += prod; - } - } + accumulate_pair_with_eq_addr_over_points(&mut ys, points, &self.addr_bits, &self.r_addr, child0, child1, |x| { + interp(chi0, chi1, x) * interp(gate0, gate1, x) * interp(term0, term1, x) + }); + }); ys } @@ -5899,7 +2460,7 @@ impl RoundOracle for TwistWriteCheckOracleSparseTime { } fn degree_bound(&self) -> usize { - self.degree_bound + 3 + self.r_addr.len() } fn fold(&mut self, r: K) { @@ -5907,10 +2468,12 @@ impl RoundOracle for TwistWriteCheckOracleSparseTime { return; } self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_write.fold_round_in_place(r); - self.wv.fold_round_in_place(r); - self.inc_at_write_addr.fold_round_in_place(r); - for col in self.wa_bits.iter_mut() { + self.gate.fold_round_in_place(r); + self.value.fold_round_in_place(r); + if let Some(inc) = self.inc_at_write_addr.as_mut() { + inc.fold_round_in_place(r); + } + for col in self.addr_bits.iter_mut() { col.fold_round_in_place(r); } self.challenges.push(r); @@ -5918,215 +2481,244 @@ impl RoundOracle for TwistWriteCheckOracleSparseTime { } } -/// Sparse time-domain val-evaluation oracle: -/// Σ_t has_write(t) · inc(t) · LT(t, r_time) · Π_b eq(wa_bit_b(t), r_addr_b) -pub struct TwistValEvalOracleSparseTime { - bit_idx: usize, - degree_bound: usize, - - r_time: Vec, - r_addr: Vec, - - wa_bits: Vec>, - has_write: SparseIdxVec, - inc_at_write_addr: SparseIdxVec, - - t_child0: Vec, - t_child1: Vec, - challenges: Vec, +pub struct TwistReadCheckOracleSparseTime { + core: TwistTimeCheckOracleSparseTimeCore, } -impl TwistValEvalOracleSparseTime { +impl TwistReadCheckOracleSparseTime { + #[allow(clippy::too_many_arguments)] pub fn new( - wa_bits: Vec>, + r_cycle: &[K], + has_read: SparseIdxVec, + rv: SparseIdxVec, + ra_bits: Vec>, has_write: SparseIdxVec, inc_at_write_addr: SparseIdxVec, + wa_bits: Vec>, r_addr: &[K], - r_time: &[K], - ) -> (Self, K) { - let ell_n = r_time.len(); + init_at_r_addr: K, + ) -> Self { let ell_addr = r_addr.len(); - debug_assert_eq!(has_write.len(), 1usize << ell_n); - debug_assert_eq!(inc_at_write_addr.len(), 1usize << ell_n); + debug_assert_eq!(ra_bits.len(), ell_addr); debug_assert_eq!(wa_bits.len(), ell_addr); - for (b, col) in wa_bits.iter().enumerate() { - debug_assert_eq!(col.len(), 1usize << ell_n, "wa_bits[{b}] length must match time domain"); - } - let mut claim = K::ZERO; - for &(t, gate) in has_write.entries() { - let inc_t = inc_at_write_addr.get(t); - if gate == K::ZERO || inc_t == K::ZERO { - continue; - } - let lt_t = lt_eval_at_bool_index(t, r_time); - let mut eq_addr = K::ONE; - for (b, col) in wa_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.get(t), r_addr[b]); - } - claim += gate * inc_t * lt_t * eq_addr; - } + let inc_terms_at_r_addr = build_inc_terms_at_r_addr(&wa_bits, &has_write, &inc_at_write_addr, r_addr); - ( - Self { - bit_idx: 0, - degree_bound: 3 + ell_addr, - r_time: r_time.to_vec(), - r_addr: r_addr.to_vec(), - wa_bits, - has_write, - inc_at_write_addr, - t_child0: vec![K::ZERO; ell_n], - t_child1: vec![K::ZERO; ell_n], - challenges: Vec::with_capacity(ell_n), - }, - claim, - ) + Self::new_with_inc_terms(r_cycle, has_read, rv, ra_bits, r_addr, init_at_r_addr, inc_terms_at_r_addr) } -} -impl RoundOracle for TwistValEvalOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_write.len() == 1 { - let mut eq_addr = K::ONE; - for (b, col) in self.wa_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.singleton_value(), self.r_addr[b]); - } - let lt = lt_eval(&self.challenges, &self.r_time); - let v = self.has_write.singleton_value() * self.inc_at_write_addr.singleton_value() * eq_addr * lt; - return vec![v; points.len()]; + pub fn new_with_inc_terms( + r_cycle: &[K], + has_read: SparseIdxVec, + rv: SparseIdxVec, + ra_bits: Vec>, + r_addr: &[K], + init_at_r_addr: K, + inc_terms_at_r_addr: Vec<(usize, K)>, + ) -> Self { + Self { + core: TwistTimeCheckOracleSparseTimeCore::new( + TwistTimeCheckKind::Read, + r_cycle, + has_read, + rv, + ra_bits, + None, + r_addr, + init_at_r_addr, + inc_terms_at_r_addr, + ), } + } +} +impl_round_oracle_via_core!(TwistReadCheckOracleSparseTime); - let pairs = gather_pairs_from_sparse(self.has_write.entries()); - let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_write.get(child0); - let gate1 = self.has_write.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - let inc0 = self.inc_at_write_addr.get(child0); - let inc1 = self.inc_at_write_addr.get(child1); - - fill_time_point(&mut self.t_child0, &self.challenges, self.bit_idx, K::ZERO, pair); - fill_time_point(&mut self.t_child1, &self.challenges, self.bit_idx, K::ONE, pair); - let lt0 = lt_eval(&self.t_child0, &self.r_time); - let lt1 = lt_eval(&self.t_child1, &self.r_time); - - let mut eq0s: Vec = Vec::with_capacity(self.wa_bits.len()); - let mut d_eqs: Vec = Vec::with_capacity(self.wa_bits.len()); - for (b, col) in self.wa_bits.iter().enumerate() { - let e0 = eq_bit_affine(col.get(child0), self.r_addr[b]); - let e1 = eq_bit_affine(col.get(child1), self.r_addr[b]); - eq0s.push(e0); - d_eqs.push(e1 - e0); - } +pub struct TwistWriteCheckOracleSparseTime { + core: TwistTimeCheckOracleSparseTimeCore, +} - for (i, &x) in points.iter().enumerate() { - let gate_x = interp(gate0, gate1, x); - let inc_x = interp(inc0, inc1, x); - let lt_x = interp(lt0, lt1, x); - let mut prod = gate_x * inc_x * lt_x; - for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { - prod *= *e0 + *de * x; - } - ys[i] += prod; - } - } - ys - } +impl TwistWriteCheckOracleSparseTime { + #[allow(clippy::too_many_arguments)] + pub fn new( + r_cycle: &[K], + has_write: SparseIdxVec, + wv: SparseIdxVec, + inc_at_write_addr: SparseIdxVec, + wa_bits: Vec>, + r_addr: &[K], + init_at_r_addr: K, + ) -> Self { + let ell_addr = r_addr.len(); + debug_assert_eq!(wa_bits.len(), ell_addr); - fn num_rounds(&self) -> usize { - self.r_time.len().saturating_sub(self.bit_idx) - } + let inc_terms_at_r_addr = build_inc_terms_at_r_addr(&wa_bits, &has_write, &inc_at_write_addr, r_addr); - fn degree_bound(&self) -> usize { - self.degree_bound + Self::new_with_inc_terms( + r_cycle, + has_write, + wv, + inc_at_write_addr, + wa_bits, + r_addr, + init_at_r_addr, + inc_terms_at_r_addr, + ) } - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.has_write.fold_round_in_place(r); - self.inc_at_write_addr.fold_round_in_place(r); - for col in self.wa_bits.iter_mut() { - col.fold_round_in_place(r); + pub fn new_with_inc_terms( + r_cycle: &[K], + has_write: SparseIdxVec, + wv: SparseIdxVec, + inc_at_write_addr: SparseIdxVec, + wa_bits: Vec>, + r_addr: &[K], + init_at_r_addr: K, + inc_terms_at_r_addr: Vec<(usize, K)>, + ) -> Self { + Self { + core: TwistTimeCheckOracleSparseTimeCore::new( + TwistTimeCheckKind::Write, + r_cycle, + has_write, + wv, + wa_bits, + Some(inc_at_write_addr), + r_addr, + init_at_r_addr, + inc_terms_at_r_addr, + ), } - self.challenges.push(r); - self.bit_idx += 1; } } +impl_round_oracle_via_core!(TwistWriteCheckOracleSparseTime); -/// Sparse time-domain total-increment oracle: -/// Σ_t has_write(t) · inc(t) · Π_b eq(wa_bit_b(t), r_addr_b) -pub struct TwistTotalIncOracleSparseTime { - degree_bound: usize, +enum TwistWriteEqAddrMode { + TotalInc, + ValEval { + r_time: Vec, + t_child0: Vec, + t_child1: Vec, + }, +} +struct TwistWriteEqAddrOracleSparseTimeCore { r_addr: Vec, wa_bits: Vec>, has_write: SparseIdxVec, inc_at_write_addr: SparseIdxVec, + mode: TwistWriteEqAddrMode, + challenges: Vec, } -impl TwistTotalIncOracleSparseTime { - pub fn new( +fn claim_write_eq_addr( + wa_bits: &[SparseIdxVec], + has_write: &SparseIdxVec, + inc_at_write_addr: &SparseIdxVec, + r_addr: &[K], + mut time_weight: F, +) -> K +where + F: FnMut(usize) -> K, +{ + let mut claim = K::ZERO; + for &(t, gate) in has_write.entries() { + let inc_t = inc_at_write_addr.get(t); + if gate == K::ZERO || inc_t == K::ZERO { + continue; + } + let eq_addr = eq_addr_at_time(wa_bits, r_addr, t); + claim += gate * inc_t * eq_addr * time_weight(t); + } + claim +} + +impl TwistWriteEqAddrOracleSparseTimeCore { + fn new_with_mode( wa_bits: Vec>, has_write: SparseIdxVec, inc_at_write_addr: SparseIdxVec, r_addr: &[K], - ) -> (Self, K) { + mode: TwistWriteEqAddrMode, + mut time_weight: F, + ) -> (Self, K) + where + F: FnMut(usize) -> K, + { let ell_n = log2_pow2(has_write.len()); let ell_addr = r_addr.len(); - debug_assert_eq!(inc_at_write_addr.len(), 1usize << ell_n); + let pow2_n = 1usize << ell_n; + debug_assert_eq!(inc_at_write_addr.len(), pow2_n); debug_assert_eq!(wa_bits.len(), ell_addr); - for (b, col) in wa_bits.iter().enumerate() { - debug_assert_eq!(col.len(), 1usize << ell_n, "wa_bits[{b}] length must match time domain"); - } - - let mut claim = K::ZERO; - for &(t, gate) in has_write.entries() { - let inc_t = inc_at_write_addr.get(t); - if gate == K::ZERO || inc_t == K::ZERO { - continue; - } - let mut eq_addr = K::ONE; - for (b, col) in wa_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.get(t), r_addr[b]); - } - claim += gate * inc_t * eq_addr; - } + assert_cols_match_time(&wa_bits, pow2_n); + let claim = claim_write_eq_addr(&wa_bits, &has_write, &inc_at_write_addr, r_addr, |t| time_weight(t)); ( Self { - degree_bound: 2 + ell_addr, r_addr: r_addr.to_vec(), wa_bits, has_write, inc_at_write_addr, + mode, + challenges: Vec::with_capacity(ell_n), }, claim, ) } + + fn new_val( + wa_bits: Vec>, + has_write: SparseIdxVec, + inc_at_write_addr: SparseIdxVec, + r_addr: &[K], + r_time: &[K], + ) -> (Self, K) { + let ell_n = r_time.len(); + debug_assert_eq!(has_write.len(), 1usize << ell_n); + Self::new_with_mode( + wa_bits, + has_write, + inc_at_write_addr, + r_addr, + TwistWriteEqAddrMode::ValEval { + r_time: r_time.to_vec(), + t_child0: vec![K::ZERO; ell_n], + t_child1: vec![K::ZERO; ell_n], + }, + |t| lt_eval_at_bool_index(t, r_time), + ) + } + + fn new_total( + wa_bits: Vec>, + has_write: SparseIdxVec, + inc_at_write_addr: SparseIdxVec, + r_addr: &[K], + ) -> (Self, K) { + Self::new_with_mode( + wa_bits, + has_write, + inc_at_write_addr, + r_addr, + TwistWriteEqAddrMode::TotalInc, + |_t| K::ONE, + ) + } } -impl RoundOracle for TwistTotalIncOracleSparseTime { +impl RoundOracle for TwistWriteEqAddrOracleSparseTimeCore { fn evals_at(&mut self, points: &[K]) -> Vec { if self.has_write.len() == 1 { - let mut eq_addr = K::ONE; - for (b, col) in self.wa_bits.iter().enumerate() { - eq_addr *= eq_bit_affine(col.singleton_value(), self.r_addr[b]); - } - let v = self.has_write.singleton_value() * self.inc_at_write_addr.singleton_value() * eq_addr; + let eq_addr = eq_addr_singleton(&self.wa_bits, &self.r_addr); + let lt = match &self.mode { + TwistWriteEqAddrMode::TotalInc => K::ONE, + TwistWriteEqAddrMode::ValEval { r_time, .. } => lt_eval(&self.challenges, r_time), + }; + let v = self.has_write.singleton_value() * self.inc_at_write_addr.singleton_value() * eq_addr * lt; return vec![v; points.len()]; } - let pairs = gather_pairs_from_sparse(self.has_write.entries()); let mut ys = vec![K::ZERO; points.len()]; - for &pair in pairs.iter() { + for_each_sparse_parent_pair!(self.has_write.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -6138,25 +2730,24 @@ impl RoundOracle for TwistTotalIncOracleSparseTime { let inc0 = self.inc_at_write_addr.get(child0); let inc1 = self.inc_at_write_addr.get(child1); - let mut eq0s: Vec = Vec::with_capacity(self.wa_bits.len()); - let mut d_eqs: Vec = Vec::with_capacity(self.wa_bits.len()); - for (b, col) in self.wa_bits.iter().enumerate() { - let e0 = eq_bit_affine(col.get(child0), self.r_addr[b]); - let e1 = eq_bit_affine(col.get(child1), self.r_addr[b]); - eq0s.push(e0); - d_eqs.push(e1 - e0); - } - - for (i, &x) in points.iter().enumerate() { - let gate_x = interp(gate0, gate1, x); - let inc_x = interp(inc0, inc1, x); - let mut prod = gate_x * inc_x; - for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { - prod *= *e0 + *de * x; - } - ys[i] += prod; - } - } + let (lt0, lt1) = match &mut self.mode { + TwistWriteEqAddrMode::TotalInc => (K::ONE, K::ONE), + TwistWriteEqAddrMode::ValEval { + r_time, + t_child0, + t_child1, + } => { + let bit_idx = self.challenges.len(); + fill_time_point(t_child0, &self.challenges, bit_idx, K::ZERO, pair); + fill_time_point(t_child1, &self.challenges, bit_idx, K::ONE, pair); + (lt_eval(t_child0, r_time), lt_eval(t_child1, r_time)) + } + }; + + accumulate_pair_with_eq_addr_over_points(&mut ys, points, &self.wa_bits, &self.r_addr, child0, child1, |x| { + interp(gate0, gate1, x) * interp(inc0, inc1, x) * interp(lt0, lt1, x) + }); + }); ys } @@ -6165,11 +2756,14 @@ impl RoundOracle for TwistTotalIncOracleSparseTime { } fn degree_bound(&self) -> usize { - self.degree_bound + match self.mode { + TwistWriteEqAddrMode::TotalInc => 2 + self.r_addr.len(), + TwistWriteEqAddrMode::ValEval { .. } => 3 + self.r_addr.len(), + } } fn fold(&mut self, r: K) { - if self.has_write.len() == 1 { + if self.num_rounds() == 0 { return; } self.has_write.fold_round_in_place(r); @@ -6177,647 +2771,156 @@ impl RoundOracle for TwistTotalIncOracleSparseTime { for col in self.wa_bits.iter_mut() { col.fold_round_in_place(r); } + self.challenges.push(r); } } -// ============================================================================ -// Val-Evaluation Oracle (SPARSE version) -// ============================================================================ - -// ============================================================================ -// Address-Lane Oracles (addr rounds first, time summed) -// ============================================================================ -// -// These are used by the "Phase 2" Route A integration to avoid materializing a -// time×addr table when Twist must share `r_time` with CCS: -// 1) Run an address-lane sum-check first (ell_addr rounds) to bind `r_addr` -// and produce the *time-lane claimed sums* (one per check). -// 2) Run a time-lane sum-check (ell_n rounds) at the fixed `r_addr`. -// -// Concretely, these oracles implement the address-lane prefix of the same -// read/write check polynomials as the 2D oracles, but with the time variables -// summed internally. This keeps address rounds efficient and avoids allocating -// `pow2_time * pow2_addr`. -// -// Degree in each address variable is ≤ 2: -// - `Val_pre(addr, t)` is multilinear in `addr` (degree 1 per bit), -// - `eq(addr, bits(t))` is multilinear in `addr`, -// so their product has degree ≤ 2 per address bit. -// -// Variable order is little-endian address bits (bit 0 first), matching the -// rest of the module. - -fn update_prefix_weights_in_place(weights: &mut [K], addrs: &[usize], bit_idx: usize, r: K) { - let r0 = K::ONE - r; - for (w, &a) in weights.iter_mut().zip(addrs.iter()) { - if ((a >> bit_idx) & 1) == 1 { - *w *= r; - } else { - *w *= r0; - } - } -} - -// ============================================================================ -// Address-lane oracles (Track A sparse-in-time variants) -// ============================================================================ - -fn addr_from_sparse_bits_at_time(bit_cols: &[SparseIdxVec], t: usize) -> usize { - let mut out = 0usize; - for (b, col) in bit_cols.iter().enumerate() { - if col.get(t) == K::ONE { - out |= 1usize << b; - } - } - out -} - -fn merge_sparse_time_indices(a: &[(usize, K)], b: &[(usize, K)]) -> Vec { - let mut out = Vec::with_capacity(a.len() + b.len()); - let mut i = 0usize; - let mut j = 0usize; - while i < a.len() || j < b.len() { - let next = match (a.get(i), b.get(j)) { - (Some(&(ai, _)), Some(&(bj, _))) => { - if ai <= bj { - i += 1; - ai - } else { - j += 1; - bj - } - } - (Some(&(ai, _)), None) => { - i += 1; - ai - } - (None, Some(&(bj, _))) => { - j += 1; - bj - } - (None, None) => break, - }; - if out.last().copied() != Some(next) { - out.push(next); - } - } - out -} - -/// Sparse-in-time address-lane prefix oracle for the Twist read-check. -/// -/// Same proof semantics as `TwistReadCheckAddrOracle`, but internal time iteration scales with -/// activity (`nnz(has_read) + nnz(has_write)`), not `T = 2^ell_n`. -pub struct TwistReadCheckAddrOracleSparseTime { - ell_addr: usize, - bit_idx: usize, - degree_bound: usize, - - mem_scratch: std::collections::HashMap, - - // Per-event (sparse time list) arrays, sorted by time index. - eq_cycle: Vec, - has_read: Vec, - rv: Vec, - has_write: Vec, - inc_at_write_addr: Vec, - ra_addrs: Vec, - wa_addrs: Vec, - ra_prefix_w: Vec, - wa_prefix_w: Vec, - - init_addrs: Vec, - init_vals: Vec, - init_prefix_w: Vec, +pub struct TwistValEvalOracleSparseTime { + core: TwistWriteEqAddrOracleSparseTimeCore, } -impl TwistReadCheckAddrOracleSparseTime { - #[allow(clippy::too_many_arguments)] +impl TwistValEvalOracleSparseTime { pub fn new( - init_sparse: Vec<(usize, K)>, - r_cycle: &[K], - has_read: SparseIdxVec, - rv: SparseIdxVec, - ra_bits: &[SparseIdxVec], + wa_bits: Vec>, has_write: SparseIdxVec, - wa_bits: &[SparseIdxVec], inc_at_write_addr: SparseIdxVec, - ) -> Self { - let pow2_time = 1usize << r_cycle.len(); - assert_eq!(has_read.len(), pow2_time, "has_read length must match time domain"); - assert_eq!(rv.len(), pow2_time, "rv length must match time domain"); - assert_eq!(has_write.len(), pow2_time, "has_write length must match time domain"); - assert_eq!( - inc_at_write_addr.len(), - pow2_time, - "inc_at_write_addr length must match time domain" - ); - - let ell_addr = ra_bits.len(); - assert_eq!(wa_bits.len(), ell_addr, "wa_bits/ra_bits length mismatch"); - let pow2_addr = 1usize << ell_addr; - for (addr, _) in init_sparse.iter() { - assert!(*addr < pow2_addr, "init address out of range"); - } - for (b, col) in ra_bits.iter().enumerate() { - assert_eq!(col.len(), pow2_time, "ra_bits[{b}] length mismatch"); - } - for (b, col) in wa_bits.iter().enumerate() { - assert_eq!(col.len(), pow2_time, "wa_bits[{b}] length mismatch"); - } - - let times = merge_sparse_time_indices(has_read.entries(), has_write.entries()); - let mut eq_cycle_out = Vec::with_capacity(times.len()); - let mut has_read_out = Vec::with_capacity(times.len()); - let mut rv_out = Vec::with_capacity(times.len()); - let mut has_write_out = Vec::with_capacity(times.len()); - let mut inc_out = Vec::with_capacity(times.len()); - let mut ra_addrs = Vec::with_capacity(times.len()); - let mut wa_addrs = Vec::with_capacity(times.len()); - - for &t in times.iter() { - eq_cycle_out.push(chi_at_bool_index(r_cycle, t)); - let hr = has_read.get(t); - let hw = has_write.get(t); - has_read_out.push(hr); - rv_out.push(rv.get(t)); - has_write_out.push(hw); - inc_out.push(inc_at_write_addr.get(t)); - - ra_addrs.push(if hr != K::ZERO { - addr_from_sparse_bits_at_time(ra_bits, t) - } else { - 0 - }); - wa_addrs.push(if hw != K::ZERO { - addr_from_sparse_bits_at_time(wa_bits, t) - } else { - 0 - }); - } - - let (init_addrs, init_vals): (Vec, Vec) = init_sparse.into_iter().unzip(); - Self { - ell_addr, - bit_idx: 0, - degree_bound: 2, - mem_scratch: std::collections::HashMap::with_capacity(init_addrs.len()), - eq_cycle: eq_cycle_out, - has_read: has_read_out, - rv: rv_out, - has_write: has_write_out, - inc_at_write_addr: inc_out, - ra_addrs, - wa_addrs, - ra_prefix_w: vec![K::ONE; times.len()], - wa_prefix_w: vec![K::ONE; times.len()], - init_prefix_w: vec![K::ONE; init_addrs.len()], - init_addrs, - init_vals, - } - } -} - -impl RoundOracle for TwistReadCheckAddrOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.num_rounds() == 0 { - let mut mem = K::ZERO; - for (&val, &w) in self.init_vals.iter().zip(self.init_prefix_w.iter()) { - mem += val * w; - } - - let mut sum = K::ZERO; - for i in 0..self.eq_cycle.len() { - let eq_t = self.eq_cycle[i]; - let gate_r = self.has_read[i]; - if gate_r != K::ZERO { - sum += eq_t * gate_r * self.ra_prefix_w[i] * (mem - self.rv[i]); - } - - let gate_w = self.has_write[i]; - if gate_w != K::ZERO { - mem += self.inc_at_write_addr[i] * gate_w * self.wa_prefix_w[i]; - } - } - return vec![sum; points.len()]; - } - - let bit_idx = self.bit_idx; - let mut ys = vec![K::ZERO; points.len()]; - - self.mem_scratch.clear(); - let mem = &mut self.mem_scratch; - for ((&addr, &val), &w) in self - .init_addrs - .iter() - .zip(self.init_vals.iter()) - .zip(self.init_prefix_w.iter()) - { - let idx = addr >> bit_idx; - let contrib = val * w; - if contrib != K::ZERO { - *mem.entry(idx).or_insert(K::ZERO) += contrib; - } - } - - for i in 0..self.eq_cycle.len() { - let eq_t = self.eq_cycle[i]; - - let gate_r = self.has_read[i]; - if gate_r != K::ZERO { - let ra = self.ra_addrs[i]; - let base = ra >> (bit_idx + 1); - let idx0 = base * 2; - let idx1 = idx0 + 1; - let v0 = mem.get(&idx0).copied().unwrap_or(K::ZERO); - let v1 = mem.get(&idx1).copied().unwrap_or(K::ZERO); - let dv = v1 - v0; - let rv_t = self.rv[i]; - let prefix = self.ra_prefix_w[i]; - let bit = (ra >> bit_idx) & 1; - - for (j, &x) in points.iter().enumerate() { - let val_x = v0 + dv * x; - let addr_factor = if bit == 1 { x } else { K::ONE - x }; - ys[j] += eq_t * gate_r * prefix * addr_factor * (val_x - rv_t); - } - } - - let gate_w = self.has_write[i]; - if gate_w != K::ZERO { - let wa = self.wa_addrs[i]; - let idx = wa >> bit_idx; - let delta = self.inc_at_write_addr[i] * gate_w * self.wa_prefix_w[i]; - if delta != K::ZERO { - *mem.entry(idx).or_insert(K::ZERO) += delta; - } - } - } - - ys - } - - fn num_rounds(&self) -> usize { - self.ell_addr.saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - update_prefix_weights_in_place(&mut self.init_prefix_w, &self.init_addrs, self.bit_idx, r); - update_prefix_weights_in_place(&mut self.ra_prefix_w, &self.ra_addrs, self.bit_idx, r); - update_prefix_weights_in_place(&mut self.wa_prefix_w, &self.wa_addrs, self.bit_idx, r); - self.bit_idx += 1; + r_addr: &[K], + r_time: &[K], + ) -> (Self, K) { + let (core, claim) = + TwistWriteEqAddrOracleSparseTimeCore::new_val(wa_bits, has_write, inc_at_write_addr, r_addr, r_time); + (Self { core }, claim) } } +impl_round_oracle_via_core!(TwistValEvalOracleSparseTime); -/// Sparse-in-time address-lane prefix oracle for the Twist write-check. -pub struct TwistWriteCheckAddrOracleSparseTime { - ell_addr: usize, - bit_idx: usize, - degree_bound: usize, - - mem_scratch: std::collections::HashMap, - - eq_cycle: Vec, - has_write: Vec, - wv: Vec, - inc_at_write_addr: Vec, - wa_addrs: Vec, - wa_prefix_w: Vec, - - init_addrs: Vec, - init_vals: Vec, - init_prefix_w: Vec, +pub struct TwistTotalIncOracleSparseTime { + core: TwistWriteEqAddrOracleSparseTimeCore, } -impl TwistWriteCheckAddrOracleSparseTime { - #[allow(clippy::too_many_arguments)] +impl TwistTotalIncOracleSparseTime { pub fn new( - init_sparse: Vec<(usize, K)>, - r_cycle: &[K], + wa_bits: Vec>, has_write: SparseIdxVec, - wv: SparseIdxVec, - wa_bits: &[SparseIdxVec], inc_at_write_addr: SparseIdxVec, - ) -> Self { - let pow2_time = 1usize << r_cycle.len(); - assert_eq!(has_write.len(), pow2_time, "has_write length must match time domain"); - assert_eq!(wv.len(), pow2_time, "wv length must match time domain"); - assert_eq!( - inc_at_write_addr.len(), - pow2_time, - "inc_at_write_addr length must match time domain" - ); - - let ell_addr = wa_bits.len(); - let pow2_addr = 1usize << ell_addr; - for (addr, _) in init_sparse.iter() { - assert!(*addr < pow2_addr, "init address out of range"); - } - for (b, col) in wa_bits.iter().enumerate() { - assert_eq!(col.len(), pow2_time, "wa_bits[{b}] length mismatch"); - } - - let times: Vec = has_write.entries().iter().map(|(t, _)| *t).collect(); - let mut eq_cycle_out = Vec::with_capacity(times.len()); - let mut has_write_out = Vec::with_capacity(times.len()); - let mut wv_out = Vec::with_capacity(times.len()); - let mut inc_out = Vec::with_capacity(times.len()); - let mut wa_addrs = Vec::with_capacity(times.len()); - - for &t in times.iter() { - eq_cycle_out.push(chi_at_bool_index(r_cycle, t)); - let hw = has_write.get(t); - has_write_out.push(hw); - wv_out.push(wv.get(t)); - inc_out.push(inc_at_write_addr.get(t)); - wa_addrs.push(addr_from_sparse_bits_at_time(wa_bits, t)); - } - - let (init_addrs, init_vals): (Vec, Vec) = init_sparse.into_iter().unzip(); - Self { - ell_addr, - bit_idx: 0, - degree_bound: 2, - mem_scratch: std::collections::HashMap::with_capacity(init_addrs.len()), - eq_cycle: eq_cycle_out, - has_write: has_write_out, - wv: wv_out, - inc_at_write_addr: inc_out, - wa_addrs, - init_prefix_w: vec![K::ONE; init_addrs.len()], - init_addrs, - init_vals, - wa_prefix_w: vec![K::ONE; times.len()], - } + r_addr: &[K], + ) -> (Self, K) { + let (core, claim) = + TwistWriteEqAddrOracleSparseTimeCore::new_total(wa_bits, has_write, inc_at_write_addr, r_addr); + (Self { core }, claim) } } +impl_round_oracle_via_core!(TwistTotalIncOracleSparseTime); -impl RoundOracle for TwistWriteCheckAddrOracleSparseTime { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.num_rounds() == 0 { - let mut mem = K::ZERO; - for (&val, &w) in self.init_vals.iter().zip(self.init_prefix_w.iter()) { - mem += val * w; - } - - let mut sum = K::ZERO; - for i in 0..self.eq_cycle.len() { - let delta = self.wv[i] - mem - self.inc_at_write_addr[i]; - sum += self.eq_cycle[i] * self.has_write[i] * self.wa_prefix_w[i] * delta; - - let gate_w = self.has_write[i]; - if gate_w != K::ZERO { - mem += self.inc_at_write_addr[i] * gate_w * self.wa_prefix_w[i]; - } - } - return vec![sum; points.len()]; - } - - let bit_idx = self.bit_idx; - let mut ys = vec![K::ZERO; points.len()]; - - self.mem_scratch.clear(); - let mem = &mut self.mem_scratch; - for ((&addr, &val), &w) in self - .init_addrs - .iter() - .zip(self.init_vals.iter()) - .zip(self.init_prefix_w.iter()) - { - let idx = addr >> bit_idx; - let contrib = val * w; - if contrib != K::ZERO { - *mem.entry(idx).or_insert(K::ZERO) += contrib; - } - } - - for i in 0..self.eq_cycle.len() { - let eq_t = self.eq_cycle[i]; - let gate = self.has_write[i]; - if gate != K::ZERO { - let wa = self.wa_addrs[i]; - let base = wa >> (bit_idx + 1); - let idx0 = base * 2; - let idx1 = idx0 + 1; - let v0 = mem.get(&idx0).copied().unwrap_or(K::ZERO); - let v1 = mem.get(&idx1).copied().unwrap_or(K::ZERO); - let dv = v1 - v0; - let wv_t = self.wv[i]; - let inc_t = self.inc_at_write_addr[i]; - let prefix = self.wa_prefix_w[i]; - let bit = (wa >> bit_idx) & 1; - - for (j, &x) in points.iter().enumerate() { - let val_x = v0 + dv * x; - let addr_factor = if bit == 1 { x } else { K::ONE - x }; - ys[j] += eq_t * gate * prefix * addr_factor * (wv_t - val_x - inc_t); - } - } - - let gate_w = self.has_write[i]; - if gate_w != K::ZERO { - let wa = self.wa_addrs[i]; - let idx = wa >> bit_idx; - let delta = self.inc_at_write_addr[i] * gate_w * self.wa_prefix_w[i]; - if delta != K::ZERO { - *mem.entry(idx).or_insert(K::ZERO) += delta; - } - } +fn update_prefix_weights_in_place(weights: &mut [K], addrs: I, bit_idx: usize, r: K) +where + I: IntoIterator, +{ + let r0 = K::ONE - r; + for (w, a) in weights.iter_mut().zip(addrs) { + if ((a >> bit_idx) & 1) == 1 { + *w *= r; + } else { + *w *= r0; } - - ys } +} - fn num_rounds(&self) -> usize { - self.ell_addr.saturating_sub(self.bit_idx) +fn addr_from_sparse_bits_at_time(bit_cols: &[SparseIdxVec], t: usize) -> usize { + let mut out = 0usize; + for (b, col) in bit_cols.iter().enumerate() { + if col.get(t) == K::ONE { + out |= 1usize << b; + } } + out +} - fn degree_bound(&self) -> usize { - self.degree_bound +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum AddrEventKind { + Read, + Write, +} + +#[derive(Clone, Copy, Debug)] +struct AddrEvent { + t: usize, + kind: AddrEventKind, + chi_t: K, + gate: K, + addr: usize, + val: K, + inc: K, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum AddrCheckMode { + Read, + Write, +} + +impl AddrCheckMode { + fn wants_check(self, kind: AddrEventKind) -> bool { + matches!( + (self, kind), + (AddrCheckMode::Read, AddrEventKind::Read) | (AddrCheckMode::Write, AddrEventKind::Write) + ) } - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; + fn check_expr(self, mem_x: K, val: K, inc: K) -> K { + match self { + AddrCheckMode::Read => mem_x - val, + AddrCheckMode::Write => val - mem_x - inc, } - update_prefix_weights_in_place(&mut self.init_prefix_w, &self.init_addrs, self.bit_idx, r); - update_prefix_weights_in_place(&mut self.wa_prefix_w, &self.wa_addrs, self.bit_idx, r); - self.bit_idx += 1; } } -/// Multi-lane variant of `TwistReadCheckAddrOracleSparseTime`. -/// -/// This oracle supports multiple Twist access lanes per CPU step by treating each lane's read/write -/// activity as an independent sparse event stream, ordered by `(time, op_kind, lane)` where -/// `op_kind` is read-before-write. -/// -/// Semantics match Track A: reads observe pre-state at time `t`, writes are applied after reads at -/// the same `t`. Multiple writes to the same address at the same `t` are **not** supported by this -/// oracle (the caller must disallow or canonicalize them). -pub struct TwistReadCheckAddrOracleSparseTimeMultiLane { +struct AddrLaneCheckCore { ell_addr: usize, bit_idx: usize, - degree_bound: usize, + mode: AddrCheckMode, - mem_scratch: std::collections::HashMap, - - // Per-event (sparse time list) arrays, sorted by (time, op_kind, lane). - eq_cycle: Vec, - has_read: Vec, - rv: Vec, - has_write: Vec, - inc_at_write_addr: Vec, - ra_addrs: Vec, - wa_addrs: Vec, - ra_prefix_w: Vec, - wa_prefix_w: Vec, + events: Vec, + event_prefix_w: Vec, init_addrs: Vec, init_vals: Vec, init_prefix_w: Vec, -} -impl TwistReadCheckAddrOracleSparseTimeMultiLane { - pub fn new(init_sparse: Vec<(usize, K)>, r_cycle: &[K], lanes: &[TwistLaneSparseCols]) -> Self { - assert!(!lanes.is_empty(), "multi-lane Twist oracle requires at least 1 lane"); + mem_scratch: std::collections::HashMap, +} - let pow2_time = 1usize << r_cycle.len(); - let ell_addr = lanes[0].ra_bits.len(); +impl AddrLaneCheckCore { + fn new(init_sparse: Vec<(usize, K)>, ell_addr: usize, mode: AddrCheckMode, mut events: Vec) -> Self { let pow2_addr = 1usize << ell_addr; - for (addr, _) in init_sparse.iter() { - assert!(*addr < pow2_addr, "init address out of range"); + debug_assert!(*addr < pow2_addr, "init address out of range"); } - - for (lane_idx, lane) in lanes.iter().enumerate() { - assert_eq!( - lane.has_read.len(), - pow2_time, - "has_read length must match time domain (lane={lane_idx})" - ); - assert_eq!( - lane.rv.len(), - pow2_time, - "rv length must match time domain (lane={lane_idx})" - ); - assert_eq!( - lane.has_write.len(), - pow2_time, - "has_write length must match time domain (lane={lane_idx})" - ); - assert_eq!( - lane.inc_at_write_addr.len(), - pow2_time, - "inc_at_write_addr length must match time domain (lane={lane_idx})" - ); - assert_eq!( - lane.ra_bits.len(), - ell_addr, - "ra_bits count must match ell_addr (lane={lane_idx})" - ); - assert_eq!( - lane.wa_bits.len(), - ell_addr, - "wa_bits count must match ell_addr (lane={lane_idx})" - ); - for (b, col) in lane.ra_bits.iter().enumerate() { - assert_eq!(col.len(), pow2_time, "ra_bits[{b}] length mismatch (lane={lane_idx})"); - } - for (b, col) in lane.wa_bits.iter().enumerate() { - assert_eq!(col.len(), pow2_time, "wa_bits[{b}] length mismatch (lane={lane_idx})"); - } + for event in events.iter() { + debug_assert!(event.addr < pow2_addr, "event address out of range"); } - // Collect per-lane sparse events: reads first, then writes, at each time. - let mut events: Vec<(usize, u8, usize)> = Vec::new(); - for (lane_idx, lane) in lanes.iter().enumerate() { - events.extend( - lane.has_read - .entries() - .iter() - .map(|&(t, _)| (t, 0u8, lane_idx)), - ); - events.extend( - lane.has_write - .entries() - .iter() - .map(|&(t, _)| (t, 1u8, lane_idx)), - ); - } - events.sort_unstable_by_key(|(t, kind, lane)| (*t, *kind, *lane)); - - let mut eq_cycle_out = Vec::with_capacity(events.len()); - let mut has_read_out = Vec::with_capacity(events.len()); - let mut rv_out = Vec::with_capacity(events.len()); - let mut has_write_out = Vec::with_capacity(events.len()); - let mut inc_out = Vec::with_capacity(events.len()); - let mut ra_addrs = Vec::with_capacity(events.len()); - let mut wa_addrs = Vec::with_capacity(events.len()); - - for (t, kind, lane_idx) in events.into_iter() { - let lane = &lanes[lane_idx]; - eq_cycle_out.push(chi_at_bool_index(r_cycle, t)); - if kind == 0 { - // Read - let hr = lane.has_read.get(t); - has_read_out.push(hr); - rv_out.push(lane.rv.get(t)); - has_write_out.push(K::ZERO); - inc_out.push(K::ZERO); - ra_addrs.push(if hr != K::ZERO { - addr_from_sparse_bits_at_time(&lane.ra_bits, t) - } else { - 0 - }); - wa_addrs.push(0); - } else { - // Write - let hw = lane.has_write.get(t); - has_read_out.push(K::ZERO); - rv_out.push(K::ZERO); - has_write_out.push(hw); - inc_out.push(lane.inc_at_write_addr.get(t)); - ra_addrs.push(0); - wa_addrs.push(if hw != K::ZERO { - addr_from_sparse_bits_at_time(&lane.wa_bits, t) - } else { - 0 - }); - } - } + events.sort_unstable_by_key(|e| (e.t, e.kind)); - let events_len = eq_cycle_out.len(); let (init_addrs, init_vals): (Vec, Vec) = init_sparse.into_iter().unzip(); + let n_events = events.len(); + let init_len = init_addrs.len(); + Self { ell_addr, bit_idx: 0, - degree_bound: 2, - mem_scratch: std::collections::HashMap::with_capacity(init_addrs.len()), - eq_cycle: eq_cycle_out, - has_read: has_read_out, - rv: rv_out, - has_write: has_write_out, - inc_at_write_addr: inc_out, - ra_addrs, - wa_addrs, - ra_prefix_w: vec![K::ONE; events_len], - wa_prefix_w: vec![K::ONE; events_len], - init_prefix_w: vec![K::ONE; init_addrs.len()], + mode, + events, + event_prefix_w: vec![K::ONE; n_events], + init_prefix_w: vec![K::ONE; init_len], init_addrs, init_vals, + mem_scratch: std::collections::HashMap::with_capacity(init_len), } } } -impl RoundOracle for TwistReadCheckAddrOracleSparseTimeMultiLane { +impl RoundOracle for AddrLaneCheckCore { fn evals_at(&mut self, points: &[K]) -> Vec { if self.num_rounds() == 0 { let mut mem = K::ZERO; @@ -6826,22 +2929,40 @@ impl RoundOracle for TwistReadCheckAddrOracleSparseTimeMultiLane { } let mut sum = K::ZERO; - for i in 0..self.eq_cycle.len() { - let diff = mem - self.rv[i]; - sum += self.eq_cycle[i] * self.has_read[i] * self.ra_prefix_w[i] * diff; + let mut i = 0usize; + while i < self.events.len() { + let t = self.events[i].t; + let start = i; + while i < self.events.len() && self.events[i].t == t { + i += 1; + } + let end = i; + + for k in start..end { + let event = self.events[k]; + if !self.mode.wants_check(event.kind) || event.gate == K::ZERO { + continue; + } + sum += event.chi_t + * event.gate + * self.event_prefix_w[k] + * self.mode.check_expr(mem, event.val, event.inc); + } - let gate_w = self.has_write[i]; - if gate_w != K::ZERO { - mem += self.inc_at_write_addr[i] * gate_w * self.wa_prefix_w[i]; + for k in start..end { + let event = self.events[k]; + if event.kind != AddrEventKind::Write || event.gate == K::ZERO { + continue; + } + mem += event.inc * event.gate * self.event_prefix_w[k]; } } + return vec![sum; points.len()]; } - let bit_idx = self.bit_idx; - let mut ys = vec![K::ZERO; points.len()]; - self.mem_scratch.clear(); + let bit_idx = self.bit_idx; let mem = &mut self.mem_scratch; for ((&addr, &val), &w) in self .init_addrs @@ -6856,33 +2977,46 @@ impl RoundOracle for TwistReadCheckAddrOracleSparseTimeMultiLane { } } - for i in 0..self.eq_cycle.len() { - let eq_t = self.eq_cycle[i]; - let gate_r = self.has_read[i]; - if gate_r != K::ZERO { - let ra = self.ra_addrs[i]; - let base = ra >> (bit_idx + 1); + let mut ys = vec![K::ZERO; points.len()]; + let mut i = 0usize; + while i < self.events.len() { + let t = self.events[i].t; + let start = i; + while i < self.events.len() && self.events[i].t == t { + i += 1; + } + let end = i; + + for k in start..end { + let event = self.events[k]; + if !self.mode.wants_check(event.kind) || event.gate == K::ZERO { + continue; + } + + let base = event.addr >> (bit_idx + 1); let idx0 = base * 2; let idx1 = idx0 + 1; let v0 = mem.get(&idx0).copied().unwrap_or(K::ZERO); let v1 = mem.get(&idx1).copied().unwrap_or(K::ZERO); let dv = v1 - v0; - let rv_t = self.rv[i]; - let prefix = self.ra_prefix_w[i]; - let bit = (ra >> bit_idx) & 1; + let bit = (event.addr >> bit_idx) & 1; + let pref = self.event_prefix_w[k]; + let coef = event.chi_t * event.gate * pref; for (j, &x) in points.iter().enumerate() { - let val_x = v0 + dv * x; + let mem_x = v0 + dv * x; let addr_factor = if bit == 1 { x } else { K::ONE - x }; - ys[j] += eq_t * gate_r * prefix * addr_factor * (val_x - rv_t); + ys[j] += coef * addr_factor * self.mode.check_expr(mem_x, event.val, event.inc); } } - let gate_w = self.has_write[i]; - if gate_w != K::ZERO { - let wa = self.wa_addrs[i]; - let idx = wa >> bit_idx; - let delta = self.inc_at_write_addr[i] * gate_w * self.wa_prefix_w[i]; + for k in start..end { + let event = self.events[k]; + if event.kind != AddrEventKind::Write || event.gate == K::ZERO { + continue; + } + let idx = event.addr >> bit_idx; + let delta = event.inc * event.gate * self.event_prefix_w[k]; if delta != K::ZERO { *mem.entry(idx).or_insert(K::ZERO) += delta; } @@ -6897,275 +3031,285 @@ impl RoundOracle for TwistReadCheckAddrOracleSparseTimeMultiLane { } fn degree_bound(&self) -> usize { - self.degree_bound + 2 } fn fold(&mut self, r: K) { if self.num_rounds() == 0 { return; } - update_prefix_weights_in_place(&mut self.init_prefix_w, &self.init_addrs, self.bit_idx, r); - update_prefix_weights_in_place(&mut self.ra_prefix_w, &self.ra_addrs, self.bit_idx, r); - update_prefix_weights_in_place(&mut self.wa_prefix_w, &self.wa_addrs, self.bit_idx, r); + update_prefix_weights_in_place( + &mut self.init_prefix_w, + self.init_addrs.iter().copied(), + self.bit_idx, + r, + ); + update_prefix_weights_in_place( + &mut self.event_prefix_w, + self.events.iter().map(|e| e.addr), + self.bit_idx, + r, + ); self.bit_idx += 1; } } -/// Multi-lane variant of `TwistWriteCheckAddrOracleSparseTime`. -/// -/// Semantics are identical: multiple writes per time are allowed as independent sparse events. -pub struct TwistWriteCheckAddrOracleSparseTimeMultiLane { - ell_addr: usize, - bit_idx: usize, - degree_bound: usize, - - mem_scratch: std::collections::HashMap, - - time_idxs: Vec, - eq_cycle: Vec, - has_write: Vec, - wv: Vec, - inc_at_write_addr: Vec, - wa_addrs: Vec, - wa_prefix_w: Vec, - - init_addrs: Vec, - init_vals: Vec, - init_prefix_w: Vec, +fn push_read_events( + out: &mut Vec, + r_cycle: &[K], + has_read: &SparseIdxVec, + rv: &SparseIdxVec, + ra_bits: &[SparseIdxVec], +) { + for &(t, gate) in has_read.entries() { + if gate == K::ZERO { + continue; + } + out.push(AddrEvent { + t, + kind: AddrEventKind::Read, + chi_t: chi_at_bool_index(r_cycle, t), + gate, + addr: addr_from_sparse_bits_at_time(ra_bits, t), + val: rv.get(t), + inc: K::ZERO, + }); + } } -impl TwistWriteCheckAddrOracleSparseTimeMultiLane { - pub fn new(init_sparse: Vec<(usize, K)>, r_cycle: &[K], lanes: &[TwistLaneSparseCols]) -> Self { - assert!(!lanes.is_empty(), "multi-lane Twist oracle requires at least 1 lane"); +fn push_write_events( + out: &mut Vec, + r_cycle: &[K], + has_write: &SparseIdxVec, + wa_bits: &[SparseIdxVec], + inc_at_write_addr: &SparseIdxVec, + wv: Option<&SparseIdxVec>, +) { + for &(t, gate) in has_write.entries() { + if gate == K::ZERO { + continue; + } + out.push(AddrEvent { + t, + kind: AddrEventKind::Write, + chi_t: chi_at_bool_index(r_cycle, t), + gate, + addr: addr_from_sparse_bits_at_time(wa_bits, t), + val: wv.map(|col| col.get(t)).unwrap_or(K::ZERO), + inc: inc_at_write_addr.get(t), + }); + } +} - let pow2_time = 1usize << r_cycle.len(); - let ell_addr = lanes[0].wa_bits.len(); - let pow2_addr = 1usize << ell_addr; +fn assert_init_sparse_in_range(init_sparse: &[(usize, K)], ell_addr: usize) { + let pow2_addr = 1usize << ell_addr; + for (addr, _) in init_sparse.iter() { + assert!(*addr < pow2_addr); + } +} - for (addr, _) in init_sparse.iter() { - assert!(*addr < pow2_addr, "init address out of range"); - } +fn assert_cols_match_time(cols: &[SparseIdxVec], pow2_time: usize) { + for col in cols { + debug_assert_eq!(col.len(), pow2_time); + } +} - for (lane_idx, lane) in lanes.iter().enumerate() { - assert_eq!( - lane.has_write.len(), - pow2_time, - "has_write length must match time domain (lane={lane_idx})" - ); - assert_eq!( - lane.wv.len(), - pow2_time, - "wv length must match time domain (lane={lane_idx})" - ); - assert_eq!( - lane.inc_at_write_addr.len(), - pow2_time, - "inc_at_write_addr length must match time domain (lane={lane_idx})" +#[allow(clippy::too_many_arguments)] +fn collect_singlelane_read_addr_events( + r_cycle: &[K], + has_read: &SparseIdxVec, + rv: &SparseIdxVec, + ra_bits: &[SparseIdxVec], + has_write: &SparseIdxVec, + wa_bits: &[SparseIdxVec], + inc_at_write_addr: &SparseIdxVec, +) -> (usize, Vec) { + let pow2_time = 1usize << r_cycle.len(); + let ell_addr = ra_bits.len(); + + debug_assert_eq!(has_read.len(), pow2_time); + debug_assert_eq!(rv.len(), pow2_time); + debug_assert_eq!(has_write.len(), pow2_time); + debug_assert_eq!(inc_at_write_addr.len(), pow2_time); + debug_assert_eq!(wa_bits.len(), ell_addr); + assert_cols_match_time(ra_bits, pow2_time); + assert_cols_match_time(wa_bits, pow2_time); + + let mut events = Vec::new(); + push_read_events(&mut events, r_cycle, has_read, rv, ra_bits); + push_write_events(&mut events, r_cycle, has_write, wa_bits, inc_at_write_addr, None); + (ell_addr, events) +} + +fn collect_singlelane_write_addr_events( + r_cycle: &[K], + has_write: &SparseIdxVec, + wv: &SparseIdxVec, + wa_bits: &[SparseIdxVec], + inc_at_write_addr: &SparseIdxVec, +) -> (usize, Vec) { + let pow2_time = 1usize << r_cycle.len(); + let ell_addr = wa_bits.len(); + + debug_assert_eq!(has_write.len(), pow2_time); + debug_assert_eq!(wv.len(), pow2_time); + debug_assert_eq!(inc_at_write_addr.len(), pow2_time); + assert_cols_match_time(wa_bits, pow2_time); + + let mut events = Vec::new(); + push_write_events(&mut events, r_cycle, has_write, wa_bits, inc_at_write_addr, Some(wv)); + (ell_addr, events) +} + +fn collect_multilane_addr_events( + r_cycle: &[K], + lanes: &[TwistLaneSparseCols], + mode: AddrCheckMode, +) -> (usize, Vec) { + assert!(!lanes.is_empty()); + let pow2_time = 1usize << r_cycle.len(); + let ell_addr = match mode { + AddrCheckMode::Read => lanes[0].ra_bits.len(), + AddrCheckMode::Write => lanes[0].wa_bits.len(), + }; + let mut events = Vec::new(); + + for lane in lanes { + debug_assert_eq!(lane.has_write.len(), pow2_time); + debug_assert_eq!(lane.inc_at_write_addr.len(), pow2_time); + debug_assert_eq!(lane.wa_bits.len(), ell_addr); + assert_cols_match_time(&lane.wa_bits, pow2_time); + + if matches!(mode, AddrCheckMode::Read) { + debug_assert_eq!(lane.has_read.len(), pow2_time); + debug_assert_eq!(lane.rv.len(), pow2_time); + debug_assert_eq!(lane.ra_bits.len(), ell_addr); + assert_cols_match_time(&lane.ra_bits, pow2_time); + push_read_events(&mut events, r_cycle, &lane.has_read, &lane.rv, &lane.ra_bits); + push_write_events( + &mut events, + r_cycle, + &lane.has_write, + &lane.wa_bits, + &lane.inc_at_write_addr, + None, ); - assert_eq!( - lane.wa_bits.len(), - ell_addr, - "wa_bits count must match ell_addr (lane={lane_idx})" + } else { + debug_assert_eq!(lane.wv.len(), pow2_time); + push_write_events( + &mut events, + r_cycle, + &lane.has_write, + &lane.wa_bits, + &lane.inc_at_write_addr, + Some(&lane.wv), ); - for (b, col) in lane.wa_bits.iter().enumerate() { - assert_eq!(col.len(), pow2_time, "wa_bits[{b}] length mismatch (lane={lane_idx})"); - } } + } + (ell_addr, events) +} - let mut events: Vec<(usize, usize)> = Vec::new(); - for (lane_idx, lane) in lanes.iter().enumerate() { - events.extend(lane.has_write.entries().iter().map(|&(t, _)| (t, lane_idx))); - } - events.sort_unstable_by_key(|(t, lane)| (*t, *lane)); - - let mut eq_cycle_out = Vec::with_capacity(events.len()); - let mut has_write_out = Vec::with_capacity(events.len()); - let mut wv_out = Vec::with_capacity(events.len()); - let mut inc_out = Vec::with_capacity(events.len()); - let mut wa_addrs = Vec::with_capacity(events.len()); - let mut time_idxs = Vec::with_capacity(events.len()); - - for (t, lane_idx) in events.into_iter() { - let lane = &lanes[lane_idx]; - let hw = lane.has_write.get(t); - time_idxs.push(t); - eq_cycle_out.push(chi_at_bool_index(r_cycle, t)); - has_write_out.push(hw); - wv_out.push(lane.wv.get(t)); - inc_out.push(lane.inc_at_write_addr.get(t)); - wa_addrs.push(addr_from_sparse_bits_at_time(&lane.wa_bits, t)); - } +pub struct TwistReadCheckAddrOracleSparseTime { + core: AddrLaneCheckCore, +} + +impl TwistReadCheckAddrOracleSparseTime { + #[allow(clippy::too_many_arguments)] + pub fn new( + init_sparse: Vec<(usize, K)>, + r_cycle: &[K], + has_read: SparseIdxVec, + rv: SparseIdxVec, + ra_bits: &[SparseIdxVec], + has_write: SparseIdxVec, + wa_bits: &[SparseIdxVec], + inc_at_write_addr: SparseIdxVec, + ) -> Self { + let (ell_addr, events) = collect_singlelane_read_addr_events( + r_cycle, + &has_read, + &rv, + ra_bits, + &has_write, + wa_bits, + &inc_at_write_addr, + ); + assert_init_sparse_in_range(&init_sparse, ell_addr); - let events_len = eq_cycle_out.len(); - let (init_addrs, init_vals): (Vec, Vec) = init_sparse.into_iter().unzip(); Self { - ell_addr, - bit_idx: 0, - degree_bound: 2, - mem_scratch: std::collections::HashMap::with_capacity(init_addrs.len()), - time_idxs, - eq_cycle: eq_cycle_out, - has_write: has_write_out, - wv: wv_out, - inc_at_write_addr: inc_out, - wa_addrs, - init_prefix_w: vec![K::ONE; init_addrs.len()], - init_addrs, - init_vals, - wa_prefix_w: vec![K::ONE; events_len], + core: AddrLaneCheckCore::new(init_sparse, ell_addr, AddrCheckMode::Read, events), } } } +impl_round_oracle_via_core!(TwistReadCheckAddrOracleSparseTime); -impl RoundOracle for TwistWriteCheckAddrOracleSparseTimeMultiLane { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.num_rounds() == 0 { - let mut mem = K::ZERO; - for (&val, &w) in self.init_vals.iter().zip(self.init_prefix_w.iter()) { - mem += val * w; - } - - let mut sum = K::ZERO; - let mut i = 0usize; - while i < self.eq_cycle.len() { - let t = self.time_idxs[i]; - let start = i; - while i < self.eq_cycle.len() && self.time_idxs[i] == t { - i += 1; - } - let end = i; - - // Evaluate write-check terms at time t using pre-state mem. - for k in start..end { - let delta = self.wv[k] - mem - self.inc_at_write_addr[k]; - sum += self.eq_cycle[k] * self.has_write[k] * self.wa_prefix_w[k] * delta; - } - - // Apply all writes at time t after checks. - for k in start..end { - let gate_w = self.has_write[k]; - if gate_w != K::ZERO { - mem += self.inc_at_write_addr[k] * gate_w * self.wa_prefix_w[k]; - } - } - } - return vec![sum; points.len()]; - } +pub struct TwistWriteCheckAddrOracleSparseTime { + core: AddrLaneCheckCore, +} - let bit_idx = self.bit_idx; - let mut ys = vec![K::ZERO; points.len()]; +impl TwistWriteCheckAddrOracleSparseTime { + #[allow(clippy::too_many_arguments)] + pub fn new( + init_sparse: Vec<(usize, K)>, + r_cycle: &[K], + has_write: SparseIdxVec, + wv: SparseIdxVec, + wa_bits: &[SparseIdxVec], + inc_at_write_addr: SparseIdxVec, + ) -> Self { + let (ell_addr, events) = collect_singlelane_write_addr_events( + r_cycle, + &has_write, + &wv, + wa_bits, + &inc_at_write_addr, + ); + assert_init_sparse_in_range(&init_sparse, ell_addr); - self.mem_scratch.clear(); - let mem = &mut self.mem_scratch; - for ((&addr, &val), &w) in self - .init_addrs - .iter() - .zip(self.init_vals.iter()) - .zip(self.init_prefix_w.iter()) - { - let idx = addr >> bit_idx; - let contrib = val * w; - if contrib != K::ZERO { - *mem.entry(idx).or_insert(K::ZERO) += contrib; - } + Self { + core: AddrLaneCheckCore::new(init_sparse, ell_addr, AddrCheckMode::Write, events), } + } +} +impl_round_oracle_via_core!(TwistWriteCheckAddrOracleSparseTime); - let mut i = 0usize; - while i < self.eq_cycle.len() { - let t = self.time_idxs[i]; - let start = i; - while i < self.eq_cycle.len() && self.time_idxs[i] == t { - i += 1; - } - let end = i; +pub struct TwistReadCheckAddrOracleSparseTimeMultiLane { + core: AddrLaneCheckCore, +} - // Evaluate write-check terms at time t using pre-state mem. - for k in start..end { - let eq_t = self.eq_cycle[k]; - let gate = self.has_write[k]; - if gate != K::ZERO { - let wa = self.wa_addrs[k]; - let base = wa >> (bit_idx + 1); - let idx0 = base * 2; - let idx1 = idx0 + 1; - let v0 = mem.get(&idx0).copied().unwrap_or(K::ZERO); - let v1 = mem.get(&idx1).copied().unwrap_or(K::ZERO); - let dv = v1 - v0; - let wv_t = self.wv[k]; - let inc_t = self.inc_at_write_addr[k]; - let prefix = self.wa_prefix_w[k]; - let bit = (wa >> bit_idx) & 1; - - for (j, &x) in points.iter().enumerate() { - let val_x = v0 + dv * x; - let addr_factor = if bit == 1 { x } else { K::ONE - x }; - ys[j] += eq_t * gate * prefix * addr_factor * (wv_t - val_x - inc_t); - } - } - } +impl TwistReadCheckAddrOracleSparseTimeMultiLane { + pub fn new(init_sparse: Vec<(usize, K)>, r_cycle: &[K], lanes: &[TwistLaneSparseCols]) -> Self { + let (ell_addr, events) = collect_multilane_addr_events(r_cycle, lanes, AddrCheckMode::Read); + assert_init_sparse_in_range(&init_sparse, ell_addr); - // Apply all writes at time t after checks. - for k in start..end { - let gate_w = self.has_write[k]; - if gate_w != K::ZERO { - let wa = self.wa_addrs[k]; - let idx = wa >> bit_idx; - let delta = self.inc_at_write_addr[k] * gate_w * self.wa_prefix_w[k]; - if delta != K::ZERO { - *mem.entry(idx).or_insert(K::ZERO) += delta; - } - } - } + Self { + core: AddrLaneCheckCore::new(init_sparse, ell_addr, AddrCheckMode::Read, events), } - - ys } +} +impl_round_oracle_via_core!(TwistReadCheckAddrOracleSparseTimeMultiLane); - fn num_rounds(&self) -> usize { - self.ell_addr.saturating_sub(self.bit_idx) - } +pub struct TwistWriteCheckAddrOracleSparseTimeMultiLane { + core: AddrLaneCheckCore, +} - fn degree_bound(&self) -> usize { - self.degree_bound - } +impl TwistWriteCheckAddrOracleSparseTimeMultiLane { + pub fn new(init_sparse: Vec<(usize, K)>, r_cycle: &[K], lanes: &[TwistLaneSparseCols]) -> Self { + let (ell_addr, events) = collect_multilane_addr_events(r_cycle, lanes, AddrCheckMode::Write); + assert_init_sparse_in_range(&init_sparse, ell_addr); - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; + Self { + core: AddrLaneCheckCore::new(init_sparse, ell_addr, AddrCheckMode::Write, events), } - update_prefix_weights_in_place(&mut self.init_prefix_w, &self.init_addrs, self.bit_idx, r); - update_prefix_weights_in_place(&mut self.wa_prefix_w, &self.wa_addrs, self.bit_idx, r); - self.bit_idx += 1; } } +impl_round_oracle_via_core!(TwistWriteCheckAddrOracleSparseTimeMultiLane); -// ============================================================================ -// Address-Domain Lookup Oracle (New Architecture) -// ============================================================================ - -/// Address-domain lookup oracle for Shout. -/// -/// Uses the identity: -/// val̃(r_cycle) = Σ_{a∈{0,1}^{ℓ_A}} Table(a) · Ã(r_cycle, a) -/// -/// where Ã(r_cycle, a) is the MLE of the one-hot adapter matrix A(t,a) = 1[a = addr(t)]. -/// -/// The sum-check is over address variables (ℓ_A rounds), not time variables. -/// At the end, the verifier checks: S_final = Tablẽ(r_addr) · Ã(r_cycle, r_addr) -/// -/// ## Advantages over time-domain approach: -/// - No need to commit `table_at_addr[t]` column -/// - Direct verification using table MLE (public) and adapter evaluation -/// - Sum-check domain is address space (often smaller than time) pub struct AddressLookupOracle { - /// Product oracle for Table(a) · weight(a) over address space core: ProductRoundOracle, } impl AddressLookupOracle { - /// Create a new address-domain lookup oracle from sparse-in-time columns. - /// - /// Track A: proof statement/verifier semantics are unchanged; only time iteration is sparse. pub fn new( addr_bits: &[SparseIdxVec], has_lookup: &SparseIdxVec, @@ -7176,15 +3320,11 @@ impl AddressLookupOracle { let pow2_cycle = 1usize << r_cycle.len(); let pow2_addr = 1usize << ell_addr; - assert_eq!(addr_bits.len(), ell_addr, "addr_bits count must match ell_addr"); - for (b, col) in addr_bits.iter().enumerate() { - assert_eq!(col.len(), pow2_cycle, "addr_bits[{b}] length must match cycle domain"); + assert_eq!(addr_bits.len(), ell_addr); + for col in addr_bits.iter() { + assert_eq!(col.len(), pow2_cycle); } - assert_eq!( - has_lookup.len(), - pow2_cycle, - "has_lookup length must match cycle domain" - ); + assert_eq!(has_lookup.len(), pow2_cycle); let mut claimed_sum = K::ZERO; let mut weight_table = vec![K::ZERO; pow2_addr]; @@ -7198,15 +3338,8 @@ impl AddressLookupOracle { continue; } - let mut addr_t = 0usize; - for (b, col) in addr_bits.iter().enumerate() { - if col.get(t) == K::ONE { - addr_t |= 1usize << b; - } - } - if addr_t < pow2_addr { - weight_table[addr_t] += weight_t; - } + let addr_t = addr_from_sparse_bits_at_time(addr_bits, t); + weight_table[addr_t] += weight_t; } for addr in 0..pow2_addr.min(table.len()) { @@ -7220,12 +3353,10 @@ impl AddressLookupOracle { (Self { core }, claimed_sum) } - /// Get the final value after all rounds (should equal Tablẽ(r_addr) · Ã(r_cycle, r_addr)) pub fn final_value(&self) -> Option { self.core.value() } - /// Get the challenges accumulated during sum-check (= r_addr) pub fn challenges(&self) -> &[K] { self.core.challenges() } @@ -7233,16 +3364,12 @@ impl AddressLookupOracle { impl_round_oracle_via_core!(AddressLookupOracle); -/// Compute the MLE of a table at a random point. -/// -/// Tablẽ(r) = Σ_{a∈{0,1}^ℓ} Table[a] · eq(r, a) pub fn table_mle_eval(table: &[K], r_addr: &[K]) -> K { let ell = r_addr.len(); let pow2 = 1usize << ell; let mut result = K::ZERO; for (idx, &val) in table.iter().enumerate().take(pow2) { - // eq(r, idx) = χ_r[idx] let weight = crate::mle::chi_at_index(r_addr, idx); result += val * weight; } From 930cf42b7820ef5875d7e32e05b205833126eb8c Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Wed, 18 Feb 2026 14:06:11 -0600 Subject: [PATCH 2/2] update Signed-off-by: Nico Arqueros --- crates/neo-memory/src/twist_oracle.rs | 673 ++++++++------------------ 1 file changed, 207 insertions(+), 466 deletions(-) diff --git a/crates/neo-memory/src/twist_oracle.rs b/crates/neo-memory/src/twist_oracle.rs index 7adcdca7..086dfb29 100644 --- a/crates/neo-memory/src/twist_oracle.rs +++ b/crates/neo-memory/src/twist_oracle.rs @@ -6,6 +6,8 @@ use neo_math::K; use neo_reductions::sumcheck::RoundOracle; use p3_field::Field; use p3_field::PrimeCharacteristicRing; +use std::borrow::Cow; +use std::sync::OnceLock; macro_rules! impl_round_oracle_via_core { ($ty:ty) => { @@ -36,9 +38,9 @@ pub struct ProductRoundOracle { impl ProductRoundOracle { pub fn new(factors: Vec>, degree_bound: usize) -> Self { let len = factors.first().map(|f| f.len()).unwrap_or(1); - assert!(len.is_power_of_two(), "factor length must be a power of two"); + debug_assert!(len.is_power_of_two(), "factor length must be a power of two"); for f in factors.iter() { - assert_eq!(f.len(), len, "all factors must have the same length"); + debug_assert_eq!(f.len(), len, "all factors must have the same length"); } let total_rounds = log2_pow2(len); Self { @@ -167,6 +169,16 @@ macro_rules! for_each_sparse_parent_pair { }}; } +fn pow2_weights_32() -> &'static [K] { + static W: OnceLock<[K; 32]> = OnceLock::new(); + W.get_or_init(|| std::array::from_fn(|i| K::from_u64(1u64 << i))) +} + +fn pow2_weights_5() -> &'static [K] { + static W: OnceLock<[K; 5]> = OnceLock::new(); + W.get_or_init(|| std::array::from_fn(|i| K::from_u64(1u64 << i))) +} + fn expr_id1(cols: &[K; 1]) -> K { cols[0] } @@ -413,253 +425,6 @@ fn expr_rv32_packed_mulhu(cols: &[K; 3], limb_sum: K) -> K { lhs * rhs - limb_sum - val * K::from_u64(1u64 << 32) } -type SparseWeightedBitsExprFn = fn(&[K; N], K) -> K; - -struct SparseWeightedBitsExprOracle { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - cols: [SparseIdxVec; N], - bits: Vec>, - weights: Vec, - degree_bound: usize, - expr_fn: SparseWeightedBitsExprFn, -} - -impl SparseWeightedBitsExprOracle { - fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - cols: [SparseIdxVec; N], - bits: Vec>, - weights: Vec, - degree_bound: usize, - expr_fn: SparseWeightedBitsExprFn, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - assert_cols_match_time(&cols, 1usize << ell_n); - debug_assert_eq!(bits.len(), weights.len()); - assert_cols_match_time(&bits, 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - cols, - bits, - weights, - degree_bound, - expr_fn, - } - } -} - -impl RoundOracle for SparseWeightedBitsExprOracle { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let cols = std::array::from_fn(|i| self.cols[i].singleton_value()); - let mut sum = K::ZERO; - for (b, w) in self.bits.iter().zip(self.weights.iter()) { - sum += b.singleton_value() * *w; - } - let expr = (self.expr_fn)(&cols, sum); - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let mut ys = vec![K::ZERO; points.len()]; - for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let cols0: [K; N] = std::array::from_fn(|i| self.cols[i].get(child0)); - let cols1: [K; N] = std::array::from_fn(|i| self.cols[i].get(child1)); - let mut sum0 = K::ZERO; - let mut sum1 = K::ZERO; - for (b, w) in self.bits.iter().zip(self.weights.iter()) { - sum0 += b.get(child0) * *w; - sum1 += b.get(child1) * *w; - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let cols_x: [K; N] = std::array::from_fn(|j| interp(cols0[j], cols1[j], x)); - let sum_x = interp(sum0, sum1, x); - let expr_x = (self.expr_fn)(&cols_x, sum_x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - }); - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - for col in self.cols.iter_mut() { - col.fold_round_in_place(r); - } - for b in self.bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -type SparseValProdExprFn = fn(K, K) -> K; - -struct SparseValProdBitsOracle { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - val: SparseIdxVec, - bits: Vec>, - degree_bound: usize, - expr_fn: SparseValProdExprFn, -} - -impl SparseValProdBitsOracle { - fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - val: SparseIdxVec, - bits: Vec>, - degree_bound: usize, - expr_fn: SparseValProdExprFn, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - debug_assert_eq!(val.len(), 1usize << ell_n); - assert_cols_match_time(&bits, 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - val, - bits, - degree_bound, - expr_fn, - } - } -} - -impl RoundOracle for SparseValProdBitsOracle { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let val = self.val.singleton_value(); - let mut prod = K::ONE; - for b in self.bits.iter() { - prod *= K::ONE - b.singleton_value(); - } - let expr = (self.expr_fn)(val, prod); - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let mut ys = vec![K::ZERO; points.len()]; - for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let val0 = self.val.get(child0); - let val1 = self.val.get(child1); - let bit_pairs: Vec<(K, K)> = self - .bits - .iter() - .map(|b| (b.get(child0), b.get(child1))) - .collect(); - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let val_x = interp(val0, val1, x); - let mut prod_x = K::ONE; - for &(b0, b1) in bit_pairs.iter() { - let bit_x = interp(b0, b1, x); - prod_x *= K::ONE - bit_x; - } - let expr_x = (self.expr_fn)(val_x, prod_x); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - }); - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - self.val.fold_round_in_place(r); - for b in self.bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - struct SparseShiftRemBoundOracle { bit_idx: usize, r_cycle: Vec, @@ -775,12 +540,6 @@ impl RoundOracle for SparseShiftRemBoundOracle { } let (b0s, b1s) = shamt_values_pair(&self.shamt_bits, child0, child1); - let mut r0s = Vec::with_capacity(self.rem_bits.len()); - let mut r1s = Vec::with_capacity(self.rem_bits.len()); - for b in self.rem_bits.iter() { - r0s.push(b.get(child0)); - r1s.push(b.get(child1)); - } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -795,7 +554,9 @@ impl RoundOracle for SparseShiftRemBoundOracle { } let shamt = shamt_values_interp(&b0s, &b1s, x); - let expr_x = shift_rem_bound_expr(shamt, self.rem_bits.len(), |j| interp(r0s[j], r1s[j], x)); + let expr_x = shift_rem_bound_expr(shamt, self.rem_bits.len(), |j| { + interp(self.rem_bits[j].get(child0), self.rem_bits[j].get(child1), x) + }); if expr_x == K::ZERO { continue; @@ -929,12 +690,6 @@ impl RoundOracle for SparseShiftExprOracle { let sign1 = self.sign.as_ref().map(|s| s.get(child1)).unwrap_or(K::ZERO); let (b0s, b1s) = shamt_values_pair(&self.shamt_bits, child0, child1); - let mut l0s = Vec::with_capacity(self.bits.len()); - let mut l1s = Vec::with_capacity(self.bits.len()); - for b in self.bits.iter() { - l0s.push(b.get(child0)); - l1s.push(b.get(child1)); - } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); @@ -956,8 +711,8 @@ impl RoundOracle for SparseShiftExprOracle { let pow2_x = pow2_from_shamt(&shamt_x); let mut limb_sum_x = K::ZERO; - for j in 0..self.bits.len() { - let bit_x = interp(l0s[j], l1s[j], x); + for (j, b) in self.bits.iter().enumerate() { + let bit_x = interp(b.get(child0), b.get(child1), x); limb_sum_x += bit_x * K::from_u64(1u64 << j); } @@ -1012,135 +767,7 @@ fn expr_rv32_packed_sra(lhs: K, val: K, pow2: K, rem: K, sign: K) -> K { lhs - val * pow2 - rem - sign * K::from_u64(1u64 << 32) * (K::ONE - pow2) } -type SparseBitsAndWeightsExprFn = fn(&[K; N], K, &[K; W]) -> K; - -struct SparseBitsAndWeightsExprOracle { - bit_idx: usize, - r_cycle: Vec, - prefix_eq: K, - has_lookup: SparseIdxVec, - cols: [SparseIdxVec; N], - bits: Vec>, - bit_weights: Vec, - expr_weights: [K; W], - degree_bound: usize, - expr_fn: SparseBitsAndWeightsExprFn, -} - -impl SparseBitsAndWeightsExprOracle { - fn new( - r_cycle: &[K], - has_lookup: SparseIdxVec, - cols: [SparseIdxVec; N], - bits: Vec>, - bit_weights: Vec, - expr_weights: [K; W], - degree_bound: usize, - expr_fn: SparseBitsAndWeightsExprFn, - ) -> Self { - let ell_n = r_cycle.len(); - debug_assert_eq!(has_lookup.len(), 1usize << ell_n); - assert_cols_match_time(&cols, 1usize << ell_n); - debug_assert_eq!(bits.len(), bit_weights.len()); - assert_cols_match_time(&bits, 1usize << ell_n); - - Self { - bit_idx: 0, - r_cycle: r_cycle.to_vec(), - prefix_eq: K::ONE, - has_lookup, - cols, - bits, - bit_weights, - expr_weights, - degree_bound, - expr_fn, - } - } -} - -impl RoundOracle for SparseBitsAndWeightsExprOracle { - fn evals_at(&mut self, points: &[K]) -> Vec { - if self.has_lookup.len() == 1 { - let gate = self.has_lookup.singleton_value(); - let cols: [K; N] = std::array::from_fn(|i| self.cols[i].singleton_value()); - let mut bit_sum = K::ZERO; - for (b, w) in self.bits.iter().zip(self.bit_weights.iter()) { - bit_sum += b.singleton_value() * *w; - } - let expr = (self.expr_fn)(&cols, bit_sum, &self.expr_weights); - let v = self.prefix_eq * gate * expr; - return vec![v; points.len()]; - } - - let mut ys = vec![K::ZERO; points.len()]; - for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { - let child0 = 2 * pair; - let child1 = child0 + 1; - - let gate0 = self.has_lookup.get(child0); - let gate1 = self.has_lookup.get(child1); - if gate0 == K::ZERO && gate1 == K::ZERO { - continue; - } - - let cols0: [K; N] = std::array::from_fn(|i| self.cols[i].get(child0)); - let cols1: [K; N] = std::array::from_fn(|i| self.cols[i].get(child1)); - let mut bit_sum0 = K::ZERO; - let mut bit_sum1 = K::ZERO; - for (b, w) in self.bits.iter().zip(self.bit_weights.iter()) { - bit_sum0 += b.get(child0) * *w; - bit_sum1 += b.get(child1) * *w; - } - - let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - - for (i, &x) in points.iter().enumerate() { - let chi_x = interp(chi0, chi1, x); - if chi_x == K::ZERO { - continue; - } - let gate_x = interp(gate0, gate1, x); - if gate_x == K::ZERO { - continue; - } - let cols_x: [K; N] = std::array::from_fn(|j| interp(cols0[j], cols1[j], x)); - let bit_sum_x = interp(bit_sum0, bit_sum1, x); - let expr_x = (self.expr_fn)(&cols_x, bit_sum_x, &self.expr_weights); - if expr_x == K::ZERO { - continue; - } - ys[i] += chi_x * gate_x * expr_x; - } - }); - ys - } - - fn num_rounds(&self) -> usize { - self.r_cycle.len().saturating_sub(self.bit_idx) - } - - fn degree_bound(&self) -> usize { - self.degree_bound - } - - fn fold(&mut self, r: K) { - if self.num_rounds() == 0 { - return; - } - self.prefix_eq *= eq_single(r, self.r_cycle[self.bit_idx]); - self.has_lookup.fold_round_in_place(r); - for col in self.cols.iter_mut() { - col.fold_round_in_place(r); - } - for b in self.bits.iter_mut() { - b.fold_round_in_place(r); - } - self.bit_idx += 1; - } -} - -type SparseColsBitsExprFn = fn(&[K; N], &[K], &[K; W]) -> K; +type SparseColsBitsExprFn = fn(&[K; N], &[K], K, &[K; W]) -> K; struct SparseColsBitsExprOracle { bit_idx: usize, @@ -1149,6 +776,8 @@ struct SparseColsBitsExprOracle { has_lookup: SparseIdxVec, cols: [SparseIdxVec; N], bits: Vec>, + bit_weights: Cow<'static, [K]>, + bits_x_scratch: Vec, expr_weights: [K; W], degree_bound: usize, expr_fn: SparseColsBitsExprFn, @@ -1160,6 +789,7 @@ impl SparseColsBitsExprOracle { has_lookup: SparseIdxVec, cols: [SparseIdxVec; N], bits: Vec>, + bit_weights: Cow<'static, [K]>, expr_weights: [K; W], degree_bound: usize, expr_fn: SparseColsBitsExprFn, @@ -1168,6 +798,7 @@ impl SparseColsBitsExprOracle { debug_assert_eq!(has_lookup.len(), 1usize << ell_n); assert_cols_match_time(&cols, 1usize << ell_n); assert_cols_match_time(&bits, 1usize << ell_n); + debug_assert!(bit_weights.is_empty() || bit_weights.len() == bits.len()); Self { bit_idx: 0, @@ -1175,7 +806,9 @@ impl SparseColsBitsExprOracle { prefix_eq: K::ONE, has_lookup, cols, + bits_x_scratch: vec![K::ZERO; bits.len()], bits, + bit_weights, expr_weights, degree_bound, expr_fn, @@ -1193,12 +826,23 @@ impl RoundOracle for SparseColsBitsExprOracle RoundOracle for SparseColsBitsExprOracle RoundOracle for SparseColsBitsExprOracle RoundOracle for SparseColsBitsExprOracle K { +fn expr_rv32_packed_mulh_adapter(cols: &[K; 7], _bits: &[K], _bit_sum: K, w: &[K; 2]) -> K { let lhs = cols[0]; let rhs = cols[1]; let lhs_sign = cols[2]; @@ -1284,7 +928,7 @@ fn expr_rv32_packed_mulh_adapter(cols: &[K; 7], _bit_sum: K, w: &[K; 2]) -> K { w[0] * eq_expr + w[1] * range } -fn expr_rv32_packed_divremu_adapter(cols: &[K; 4], bit_sum: K, w: &[K; 4]) -> K { +fn expr_rv32_packed_divremu_adapter(cols: &[K; 4], _bits: &[K], bit_sum: K, w: &[K; 4]) -> K { let rhs = cols[0]; let z = cols[1]; let rem = cols[2]; @@ -1296,7 +940,7 @@ fn expr_rv32_packed_divremu_adapter(cols: &[K; 4], bit_sum: K, w: &[K; 4]) -> K w[0] * c0 + w[1] * c1 + w[2] * c2 + w[3] * c3 } -fn expr_rv32_packed_divrem_adapter(cols: &[K; 10], bit_sum: K, w: &[K; 7]) -> K { +fn expr_rv32_packed_divrem_adapter(cols: &[K; 10], _bits: &[K], bit_sum: K, w: &[K; 7]) -> K { let lhs = cols[0]; let rhs = cols[1]; let z = cols[2]; @@ -1332,6 +976,22 @@ fn expr_u_decomp(cols: &[K; 1], sum: K) -> K { cols[0] - sum } +fn expr_rv32_packed_mul_bw0(cols: &[K; 3], _bits: &[K], bit_sum: K, _w: &[K; 0]) -> K { + expr_rv32_packed_mul(cols, bit_sum) +} + +fn expr_rv32_packed_mulhu_bw0(cols: &[K; 3], _bits: &[K], bit_sum: K, _w: &[K; 0]) -> K { + expr_rv32_packed_mulhu(cols, bit_sum) +} + +fn expr_rv32_packed_eq_adapter_bw0(cols: &[K; 3], _bits: &[K], bit_sum: K, _w: &[K; 0]) -> K { + expr_rv32_packed_eq_adapter(cols, bit_sum) +} + +fn expr_u_decomp_bw0(cols: &[K; 1], _bits: &[K], bit_sum: K, _w: &[K; 0]) -> K { + expr_u_decomp(cols, bit_sum) +} + fn expr_eq_from_prod(val: K, prod: K) -> K { val - prod } @@ -1340,6 +1000,22 @@ fn expr_neq_from_prod(val: K, prod: K) -> K { val + prod - K::ONE } +fn expr_eq_from_prod_bits(cols: &[K; 1], bits: &[K], _bit_sum: K, _w: &[K; 0]) -> K { + let mut prod = K::ONE; + for &b in bits { + prod *= K::ONE - b; + } + expr_eq_from_prod(cols[0], prod) +} + +fn expr_neq_from_prod_bits(cols: &[K; 1], bits: &[K], _bit_sum: K, _w: &[K; 0]) -> K { + let mut prod = K::ONE; + for &b in bits { + prod *= K::ONE - b; + } + expr_neq_from_prod(cols[0], prod) +} + macro_rules! define_sparse_time_expr_oracle { ($name:ident, $n:expr, $degree:expr, $expr:expr, [$($col:ident),+ $(,)?]) => { pub struct $name { @@ -1387,7 +1063,7 @@ define_sparse_time_expr_oracle!( macro_rules! define_weighted_bits32_oracle3 { ($name:ident, [$c0:ident, $c1:ident, $c2:ident], $expr:expr, $degree:expr) => { pub struct $name { - core: SparseWeightedBitsExprOracle<3>, + core: SparseColsBitsExprOracle<3, 0>, } impl $name { @@ -1400,14 +1076,14 @@ macro_rules! define_weighted_bits32_oracle3 { bits: Vec>, ) -> Self { debug_assert_eq!(bits.len(), 32); - let weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); Self { - core: SparseWeightedBitsExprOracle::new( + core: SparseColsBitsExprOracle::new( r_cycle, has_lookup, [$c0, $c1, $c2], bits, - weights, + Cow::Borrowed(pow2_weights_32()), + [], $degree, $expr, ), @@ -1421,7 +1097,7 @@ macro_rules! define_weighted_bits32_oracle3 { macro_rules! define_weighted_bits32_oracle3_bits_before_last { ($name:ident, [$c0:ident, $c1:ident], $bits:ident, $c2:ident, $expr:expr, $degree:expr) => { pub struct $name { - core: SparseWeightedBitsExprOracle<3>, + core: SparseColsBitsExprOracle<3, 0>, } impl $name { @@ -1434,14 +1110,14 @@ macro_rules! define_weighted_bits32_oracle3_bits_before_last { $c2: SparseIdxVec, ) -> Self { debug_assert_eq!($bits.len(), 32); - let weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); Self { - core: SparseWeightedBitsExprOracle::new( + core: SparseColsBitsExprOracle::new( r_cycle, has_lookup, [$c0, $c1, $c2], $bits, - weights, + Cow::Borrowed(pow2_weights_32()), + [], $degree, $expr, ), @@ -1457,7 +1133,7 @@ define_weighted_bits32_oracle3_bits_before_last!( [lhs, rhs], carry_bits, val, - expr_rv32_packed_mul, + expr_rv32_packed_mul_bw0, 4 ); @@ -1466,7 +1142,7 @@ define_weighted_bits32_oracle3_bits_before_last!( [lhs, rhs], lo_bits, val, - expr_rv32_packed_mulhu, + expr_rv32_packed_mulhu_bw0, 4 ); @@ -1475,12 +1151,12 @@ define_weighted_bits32_oracle3_bits_before_last!( [lhs, rhs], lo_bits, hi, - expr_rv32_packed_mulhu, + expr_rv32_packed_mulhu_bw0, 4 ); pub struct Rv32PackedMulhAdapterOracleSparseTime { - core: SparseBitsAndWeightsExprOracle<7, 2>, + core: SparseColsBitsExprOracle<7, 2>, } impl Rv32PackedMulhAdapterOracleSparseTime { @@ -1497,12 +1173,12 @@ impl Rv32PackedMulhAdapterOracleSparseTime { weights: [K; 2], ) -> Self { Self { - core: SparseBitsAndWeightsExprOracle::new( + core: SparseColsBitsExprOracle::new( r_cycle, has_lookup, [lhs, rhs, lhs_sign, rhs_sign, hi, k, val], Vec::new(), - Vec::new(), + Cow::Borrowed(&[]), weights, 5, expr_rv32_packed_mulh_adapter, @@ -1523,7 +1199,7 @@ define_sparse_time_expr_oracle!( macro_rules! define_val_prod_bits_oracle { ($name:ident, $expr:expr) => { pub struct $name { - core: SparseValProdBitsOracle, + core: SparseColsBitsExprOracle<1, 0>, } impl $name { @@ -1535,7 +1211,16 @@ macro_rules! define_val_prod_bits_oracle { ) -> Self { debug_assert_eq!(diff_bits.len(), 32); Self { - core: SparseValProdBitsOracle::new(r_cycle, has_lookup, val, diff_bits, 34, $expr), + core: SparseColsBitsExprOracle::new( + r_cycle, + has_lookup, + [val], + diff_bits, + Cow::Borrowed(&[]), + [], + 34, + $expr, + ), } } } @@ -1543,21 +1228,21 @@ macro_rules! define_val_prod_bits_oracle { }; } -define_val_prod_bits_oracle!(Rv32PackedEqOracleSparseTime, expr_eq_from_prod); +define_val_prod_bits_oracle!(Rv32PackedEqOracleSparseTime, expr_eq_from_prod_bits); define_weighted_bits32_oracle3!( Rv32PackedEqAdapterOracleSparseTime, [lhs, rhs, borrow], - expr_rv32_packed_eq_adapter, + expr_rv32_packed_eq_adapter_bw0, 3 ); -define_val_prod_bits_oracle!(Rv32PackedNeqOracleSparseTime, expr_neq_from_prod); +define_val_prod_bits_oracle!(Rv32PackedNeqOracleSparseTime, expr_neq_from_prod_bits); define_weighted_bits32_oracle3!( Rv32PackedNeqAdapterOracleSparseTime, [lhs, rhs, borrow], - expr_rv32_packed_eq_adapter, + expr_rv32_packed_eq_adapter_bw0, 3 ); @@ -1692,7 +1377,7 @@ define_sparse_time_expr_oracle!( macro_rules! define_bits_and_weights32_oracle { ($name:ident, $n:expr, $m:expr, [$($col:ident),+ $(,)?], $degree:expr, $expr:expr) => { pub struct $name { - core: SparseBitsAndWeightsExprOracle<$n, $m>, + core: SparseColsBitsExprOracle<$n, $m>, } impl $name { @@ -1704,14 +1389,13 @@ macro_rules! define_bits_and_weights32_oracle { weights: [K; $m], ) -> Self { debug_assert_eq!(diff_bits.len(), 32); - let bit_weights: Vec = (0..32).map(|i| K::from_u64(1u64 << i)).collect(); Self { - core: SparseBitsAndWeightsExprOracle::new( + core: SparseColsBitsExprOracle::new( r_cycle, has_lookup, [$($col),+], diff_bits, - bit_weights, + Cow::Borrowed(pow2_weights_32()), weights, $degree, $expr, @@ -1824,7 +1508,7 @@ fn rv32_digit4_range_poly(x: K) -> K { x * (x - K::ONE) * (x - K::from_u64(2)) * (x - K::from_u64(3)) } -fn expr_rv32_packed_bitwise_adapter(cols: &[K; 2], bits: &[K], w: &[K; 34]) -> K { +fn expr_rv32_packed_bitwise_adapter(cols: &[K; 2], bits: &[K], _bit_sum: K, w: &[K; 34]) -> K { let lhs = cols[0]; let rhs = cols[1]; debug_assert_eq!(bits.len(), 32); @@ -1873,6 +1557,7 @@ impl Rv32PackedBitwiseAdapterOracleSparseTime { has_lookup, [lhs, rhs], digits, + Cow::Borrowed(&[]), expr_weights, 6, expr_rv32_packed_bitwise_adapter, @@ -1882,7 +1567,7 @@ impl Rv32PackedBitwiseAdapterOracleSparseTime { } impl_round_oracle_via_core!(Rv32PackedBitwiseAdapterOracleSparseTime); -fn expr_rv32_packed_bitwise(cols: &[K; 1], bits: &[K], w: &[K; 2], op: Rv32PackedBitwiseOp2) -> K { +fn expr_rv32_packed_bitwise(cols: &[K; 1], bits: &[K], _bit_sum: K, w: &[K; 2], op: Rv32PackedBitwiseOp2) -> K { let val = cols[0]; debug_assert_eq!(bits.len(), 32); let mut out = K::ZERO; @@ -1895,20 +1580,20 @@ fn expr_rv32_packed_bitwise(cols: &[K; 1], bits: &[K], w: &[K; 2], op: Rv32Packe out - val } -fn expr_rv32_packed_and(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { - expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::And) +fn expr_rv32_packed_and(cols: &[K; 1], bits: &[K], bit_sum: K, w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, bit_sum, w, Rv32PackedBitwiseOp2::And) } -fn expr_rv32_packed_andn(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { - expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::Andn) +fn expr_rv32_packed_andn(cols: &[K; 1], bits: &[K], bit_sum: K, w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, bit_sum, w, Rv32PackedBitwiseOp2::Andn) } -fn expr_rv32_packed_or(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { - expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::Or) +fn expr_rv32_packed_or(cols: &[K; 1], bits: &[K], bit_sum: K, w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, bit_sum, w, Rv32PackedBitwiseOp2::Or) } -fn expr_rv32_packed_xor(cols: &[K; 1], bits: &[K], w: &[K; 2]) -> K { - expr_rv32_packed_bitwise(cols, bits, w, Rv32PackedBitwiseOp2::Xor) +fn expr_rv32_packed_xor(cols: &[K; 1], bits: &[K], bit_sum: K, w: &[K; 2]) -> K { + expr_rv32_packed_bitwise(cols, bits, bit_sum, w, Rv32PackedBitwiseOp2::Xor) } macro_rules! define_rv32_packed_bitwise_oracle { @@ -1931,7 +1616,16 @@ macro_rules! define_rv32_packed_bitwise_oracle { let inv2 = K::from_u64(2).inverse(); let inv6 = K::from_u64(6).inverse(); Self { - core: SparseColsBitsExprOracle::new(r_cycle, has_lookup, [val], digits, [inv2, inv6], 8, $expr), + core: SparseColsBitsExprOracle::new( + r_cycle, + has_lookup, + [val], + digits, + Cow::Borrowed(&[]), + [inv2, inv6], + 8, + $expr, + ), } } } @@ -1947,7 +1641,7 @@ define_rv32_packed_bitwise_oracle!(Rv32PackedXorOracleSparseTime, expr_rv32_pack macro_rules! define_u_decomp_oracle { ($name:ident, $num_bits:expr) => { pub struct $name { - core: SparseWeightedBitsExprOracle<1>, + core: SparseColsBitsExprOracle<1, 0>, } impl $name { @@ -1958,9 +1652,22 @@ macro_rules! define_u_decomp_oracle { bits: Vec>, ) -> Self { debug_assert_eq!(bits.len(), $num_bits); - let weights: Vec = (0..$num_bits).map(|i| K::from_u64(1u64 << i)).collect(); + let weights: Cow<'static, [K]> = match $num_bits { + 32 => Cow::Borrowed(pow2_weights_32()), + 5 => Cow::Borrowed(pow2_weights_5()), + _ => Cow::Owned((0..$num_bits).map(|i| K::from_u64(1u64 << i)).collect()), + }; Self { - core: SparseWeightedBitsExprOracle::new(r_cycle, has_lookup, [x], bits, weights, 3, expr_u_decomp), + core: SparseColsBitsExprOracle::new( + r_cycle, + has_lookup, + [x], + bits, + weights, + [], + 3, + expr_u_decomp_bw0, + ), } } } @@ -2095,22 +1802,26 @@ fn accumulate_pair_with_eq_addr_over_points( r_addr: &[K], child0: usize, child1: usize, + eq0s_scratch: &mut Vec, + d_eqs_scratch: &mut Vec, mut coeff_at: F, ) where F: FnMut(K) -> K, { debug_assert_eq!(bit_cols.len(), r_addr.len()); - let mut eq0s = Vec::with_capacity(bit_cols.len()); - let mut d_eqs = Vec::with_capacity(bit_cols.len()); + eq0s_scratch.clear(); + d_eqs_scratch.clear(); + eq0s_scratch.reserve(bit_cols.len()); + d_eqs_scratch.reserve(bit_cols.len()); for (b, col) in bit_cols.iter().enumerate() { let e0 = eq_bit_affine(col.get(child0), r_addr[b]); - eq0s.push(e0); - d_eqs.push(eq_bit_affine(col.get(child1), r_addr[b]) - e0); + eq0s_scratch.push(e0); + d_eqs_scratch.push(eq_bit_affine(col.get(child1), r_addr[b]) - e0); } for (i, &x) in points.iter().enumerate() { let mut eq_addr = K::ONE; - for (e0, de) in eq0s.iter().zip(d_eqs.iter()) { + for (e0, de) in eq0s_scratch.iter().zip(d_eqs_scratch.iter()) { eq_addr *= *e0 + *de * x; } ys[i] += coeff_at(x) * eq_addr; @@ -2179,6 +1890,8 @@ impl RoundOracle for IndexAdapterOracleSparseTime { } let mut ys = vec![K::ZERO; points.len()]; + let mut eq0s_scratch = Vec::with_capacity(self.addr_bits.len()); + let mut d_eqs_scratch = Vec::with_capacity(self.addr_bits.len()); for_each_sparse_parent_pair!(self.has_lookup.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -2190,9 +1903,17 @@ impl RoundOracle for IndexAdapterOracleSparseTime { } let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - accumulate_pair_with_eq_addr_over_points(&mut ys, points, &self.addr_bits, &self.r_addr, child0, child1, |x| { - interp(chi0, chi1, x) * interp(gate0, gate1, x) - }); + accumulate_pair_with_eq_addr_over_points( + &mut ys, + points, + &self.addr_bits, + &self.r_addr, + child0, + child1, + &mut eq0s_scratch, + &mut d_eqs_scratch, + |x| interp(chi0, chi1, x) * interp(gate0, gate1, x), + ); }); ys } @@ -2409,6 +2130,8 @@ impl RoundOracle for TwistTimeCheckOracleSparseTimeCore { } let mut ys = vec![K::ZERO; points.len()]; + let mut eq0s_scratch = Vec::with_capacity(self.addr_bits.len()); + let mut d_eqs_scratch = Vec::with_capacity(self.addr_bits.len()); for_each_sparse_parent_pair!(self.gate.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -2448,9 +2171,17 @@ impl RoundOracle for TwistTimeCheckOracleSparseTimeCore { }; let (chi0, chi1) = chi_cycle_children(&self.r_cycle, self.bit_idx, self.prefix_eq, pair); - accumulate_pair_with_eq_addr_over_points(&mut ys, points, &self.addr_bits, &self.r_addr, child0, child1, |x| { - interp(chi0, chi1, x) * interp(gate0, gate1, x) * interp(term0, term1, x) - }); + accumulate_pair_with_eq_addr_over_points( + &mut ys, + points, + &self.addr_bits, + &self.r_addr, + child0, + child1, + &mut eq0s_scratch, + &mut d_eqs_scratch, + |x| interp(chi0, chi1, x) * interp(gate0, gate1, x) * interp(term0, term1, x), + ); }); ys } @@ -2718,6 +2449,8 @@ impl RoundOracle for TwistWriteEqAddrOracleSparseTimeCore { } let mut ys = vec![K::ZERO; points.len()]; + let mut eq0s_scratch = Vec::with_capacity(self.wa_bits.len()); + let mut d_eqs_scratch = Vec::with_capacity(self.wa_bits.len()); for_each_sparse_parent_pair!(self.has_write.entries(), pair, { let child0 = 2 * pair; let child1 = child0 + 1; @@ -2744,9 +2477,17 @@ impl RoundOracle for TwistWriteEqAddrOracleSparseTimeCore { } }; - accumulate_pair_with_eq_addr_over_points(&mut ys, points, &self.wa_bits, &self.r_addr, child0, child1, |x| { - interp(gate0, gate1, x) * interp(inc0, inc1, x) * interp(lt0, lt1, x) - }); + accumulate_pair_with_eq_addr_over_points( + &mut ys, + points, + &self.wa_bits, + &self.r_addr, + child0, + child1, + &mut eq0s_scratch, + &mut d_eqs_scratch, + |x| interp(gate0, gate1, x) * interp(inc0, inc1, x) * interp(lt0, lt1, x), + ); }); ys } @@ -3104,7 +2845,7 @@ fn push_write_events( fn assert_init_sparse_in_range(init_sparse: &[(usize, K)], ell_addr: usize) { let pow2_addr = 1usize << ell_addr; for (addr, _) in init_sparse.iter() { - assert!(*addr < pow2_addr); + debug_assert!(*addr < pow2_addr); } } @@ -3166,7 +2907,7 @@ fn collect_multilane_addr_events( lanes: &[TwistLaneSparseCols], mode: AddrCheckMode, ) -> (usize, Vec) { - assert!(!lanes.is_empty()); + debug_assert!(!lanes.is_empty()); let pow2_time = 1usize << r_cycle.len(); let ell_addr = match mode { AddrCheckMode::Read => lanes[0].ra_bits.len(), @@ -3320,11 +3061,11 @@ impl AddressLookupOracle { let pow2_cycle = 1usize << r_cycle.len(); let pow2_addr = 1usize << ell_addr; - assert_eq!(addr_bits.len(), ell_addr); + debug_assert_eq!(addr_bits.len(), ell_addr); for col in addr_bits.iter() { - assert_eq!(col.len(), pow2_cycle); + debug_assert_eq!(col.len(), pow2_cycle); } - assert_eq!(has_lookup.len(), pow2_cycle); + debug_assert_eq!(has_lookup.len(), pow2_cycle); let mut claimed_sum = K::ZERO; let mut weight_table = vec![K::ZERO; pow2_addr];