Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions crates/neo-ajtai/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use p3_field::PrimeCharacteristicRing;
use p3_goldilocks::Goldilocks as Fq;
use serde::{Deserialize, Serialize};
use serde::de::Error as _;
use serde::{Deserialize, Deserializer, Serialize};

/// Public parameters for Ajtai: M ∈ R_q^{κ×m}, stored row-major.
#[derive(Clone, Debug, Serialize, Deserialize)]
Expand All @@ -13,7 +14,7 @@ pub struct PP<RqEl> {
}

/// Commitment c ∈ F_q^{d×κ}, stored as column-major flat matrix (κ columns, each length d).
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
pub struct Commitment {
pub d: usize,
pub kappa: usize,
Expand All @@ -22,6 +23,26 @@ pub struct Commitment {
}

impl Commitment {
#[inline]
fn validate_shape(d: usize, kappa: usize, data_len: usize) -> Result<(), String> {
let expected_d = neo_math::ring::D;
if d != expected_d {
return Err(format!("invalid Commitment.d: expected {expected_d}, got {d}"));
}

let expected_len = d
.checked_mul(kappa)
.ok_or_else(|| format!("invalid Commitment shape: d*kappa overflow (d={d}, kappa={kappa})"))?;
if data_len != expected_len {
return Err(format!(
"invalid Commitment shape: data.len()={} but d*kappa={expected_len}",
data_len
));
}

Ok(())
}

pub fn zeros(d: usize, kappa: usize) -> Self {
Self {
d,
Expand All @@ -48,3 +69,26 @@ impl Commitment {
}
}
}

impl<'de> Deserialize<'de> for Commitment {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct CommitmentWire {
d: usize,
kappa: usize,
data: Vec<Fq>,
}

let wire = CommitmentWire::deserialize(deserializer)?;
Commitment::validate_shape(wire.d, wire.kappa, wire.data.len()).map_err(D::Error::custom)?;

Ok(Self {
d: wire.d,
kappa: wire.kappa,
data: wire.data,
})
}
}
41 changes: 41 additions & 0 deletions crates/neo-ajtai/tests/commitment_deserialize_invariants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use neo_ajtai::Commitment;
use neo_math::ring::D;
use serde_json::Value;

fn valid_commitment_json(kappa: usize) -> Value {
let c = Commitment::zeros(D, kappa);
serde_json::to_value(&c).expect("serialize valid commitment")
}

#[test]
fn commitment_deserialize_rejects_wrong_d() {
let mut value = valid_commitment_json(2);
value["d"] = serde_json::json!(D - 1);

let err = serde_json::from_value::<Commitment>(value).expect_err("wrong d must be rejected");
let msg = err.to_string();
assert!(msg.contains("invalid Commitment.d"), "unexpected error message: {msg}");
}

#[test]
fn commitment_deserialize_rejects_data_len_mismatch() {
let mut value = valid_commitment_json(2);
let data = value
.get_mut("data")
.and_then(Value::as_array_mut)
.expect("commitment.data array");
data.pop().expect("non-empty data");

let err = serde_json::from_value::<Commitment>(value).expect_err("shape mismatch must be rejected");
let msg = err.to_string();
assert!(msg.contains("data.len()"), "unexpected error message: {msg}");
}

#[test]
fn commitment_deserialize_accepts_valid_shape() {
let value = valid_commitment_json(3);
let c: Commitment = serde_json::from_value(value).expect("valid shape should deserialize");
assert_eq!(c.d, D);
assert_eq!(c.kappa, 3);
assert_eq!(c.data.len(), D * 3);
}
46 changes: 30 additions & 16 deletions crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ pub(crate) fn build_route_a_control_time_claims(
trace.rd_val,
trace.shout_val,
trace.jalr_drop_bit,
trace.pc_carry,
];
let decode_col_ids = vec![
decode.op_lui,
Expand Down Expand Up @@ -575,29 +576,42 @@ pub(crate) fn build_route_a_control_time_claims(
);

let control_sparse = vec![
main_col(trace.active)?,
main_col(trace.pc_before)?,
main_col(trace.pc_after)?,
main_col(trace.rs1_val)?,
main_col(trace.jalr_drop_bit)?,
main_col(trace.shout_val)?,
decode_col(decode.funct3_bit[0])?,
decode_col(decode.op_jal)?,
decode_col(decode.op_jalr)?,
decode_col(decode.op_branch)?,
decode_col(decode.imm_i)?,
decode_col(decode.imm_b)?,
decode_col(decode.imm_j)?,
main_col(trace.active)?, // 0
main_col(trace.pc_before)?, // 1
main_col(trace.pc_after)?, // 2
main_col(trace.rs1_val)?, // 3
main_col(trace.jalr_drop_bit)?, // 4
main_col(trace.pc_carry)?, // 5
main_col(trace.shout_val)?, // 6
decode_col(decode.funct3_bit[0])?,// 7
decode_col(decode.op_jal)?, // 8
decode_col(decode.op_jalr)?, // 9
decode_col(decode.op_branch)?, // 10
decode_col(decode.imm_i)?, // 11
decode_col(decode.imm_b)?, // 12
decode_col(decode.imm_j)?, // 13
];
let control_weights = control_next_pc_control_weight_vector(r_cycle, 5);
let control_weights = control_next_pc_control_weight_vector(r_cycle, 7);
let control_oracle = FormulaOracleSparseTime::new(
control_sparse,
5,
r_cycle,
Box::new(move |vals: &[K]| {
let residuals = control_next_pc_control_residuals(
vals[0], vals[1], vals[2], vals[3], vals[4], vals[10], vals[11], vals[12], vals[7], vals[8], vals[9],
vals[5], vals[6],
vals[0], // active
vals[1], // pc_before
vals[2], // pc_after
vals[3], // rs1_val
vals[4], // jalr_drop_bit
vals[5], // pc_carry
vals[11], // imm_i
vals[12], // imm_b
vals[13], // imm_j
vals[8], // op_jal
vals[9], // op_jalr
vals[10], // op_branch
vals[6], // shout_val
vals[7], // funct3_bit0
);
let mut weighted = K::ZERO;
for (r, w) in residuals.iter().zip(control_weights.iter()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ pub(crate) fn verify_route_a_control_terminals(
let rs1_val = wp_open_col(trace.rs1_val)?;
let rd_val = wp_open_col(trace.rd_val)?;
let jalr_drop_bit = wp_open_col(trace.jalr_drop_bit)?;
let pc_carry = wp_open_col(trace.pc_carry)?;
let shout_val = wp_open_col(trace.shout_val)?;
let funct3_bits = [
decode_open_col(decode.funct3_bit[0])?,
Expand Down Expand Up @@ -756,6 +757,7 @@ pub(crate) fn verify_route_a_control_terminals(
pc_after,
rs1_val,
jalr_drop_bit,
pc_carry,
imm_i,
imm_b,
imm_j,
Expand Down
33 changes: 23 additions & 10 deletions crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ pub(crate) fn rv32_trace_wb_columns(layout: &Rv32TraceLayout) -> Vec<usize> {
vec![layout.active, layout.halted, layout.shout_has_lookup]
}

pub(crate) const W2_FIELDS_RESIDUAL_COUNT: usize = 70;
pub(crate) const W2_FIELDS_RESIDUAL_COUNT: usize = 76;
pub(crate) const W2_IMM_RESIDUAL_COUNT: usize = 4;

#[inline]
Expand Down Expand Up @@ -929,12 +929,14 @@ pub(crate) fn w2_alu_branch_lookup_residuals(
rs2_decode: K,
imm_i: K,
imm_s: K,
) -> [K; 42] {
) -> [K; 48] {
let op_lui = opcode_flags[0];
let op_auipc = opcode_flags[1];
let op_jal = opcode_flags[2];
let op_jalr = opcode_flags[3];
let op_branch = opcode_flags[4];
let op_load = opcode_flags[5];
let op_store = opcode_flags[6];
let op_alu_imm = opcode_flags[7];
let op_alu_reg = opcode_flags[8];
let op_misc_mem = opcode_flags[9];
Expand Down Expand Up @@ -966,7 +968,7 @@ pub(crate) fn w2_alu_branch_lookup_residuals(
op_alu_reg * (shout_has_lookup - K::ONE),
op_branch * (shout_has_lookup - K::ONE),
(K::ONE - shout_has_lookup) * shout_table_id,
(op_alu_imm + op_alu_reg + op_branch) * (shout_lhs - rs1_val),
(op_alu_imm + op_alu_reg + op_branch + op_load + op_store) * (shout_lhs - rs1_val),
alu_imm_shift_rhs_delta - shift_selector * (rs2_decode - imm_i),
op_alu_imm * (shout_rhs - imm_i - alu_imm_shift_rhs_delta),
op_alu_reg * (shout_rhs - rs2_val),
Expand Down Expand Up @@ -996,14 +998,20 @@ pub(crate) fn w2_alu_branch_lookup_residuals(
opcode_flags[6] * rd_has_write,
op_misc_mem * rd_has_write,
op_system * rd_has_write,
active * (halted - op_system),
active * halted * (K::ONE - op_system),
opcode_flags[5] * (ram_has_read - K::ONE),
opcode_flags[6] * (ram_has_write - K::ONE),
non_mem_ops * ram_has_read,
non_mem_ops * ram_has_write,
non_mem_ops * ram_addr,
opcode_flags[5] * (ram_addr - rs1_val - imm_i),
opcode_flags[6] * (ram_addr - rs1_val - imm_s),
op_load * (ram_addr - shout_val),
op_store * (ram_addr - shout_val),
op_load * (shout_has_lookup - K::ONE),
op_store * (shout_has_lookup - K::ONE),
op_load * (shout_rhs - imm_i),
op_store * (shout_rhs - imm_s),
op_load * (shout_table_id - K::from(F::from_u64(3))),
op_store * (shout_table_id - K::from(F::from_u64(3))),
]
}

Expand Down Expand Up @@ -1256,6 +1264,7 @@ pub(crate) fn control_next_pc_control_residuals(
pc_after: K,
rs1_val: K,
jalr_drop_bit: K,
pc_carry: K,
imm_i: K,
imm_b: K,
imm_j: K,
Expand All @@ -1264,15 +1273,18 @@ pub(crate) fn control_next_pc_control_residuals(
op_branch: K,
shout_val: K,
funct3_bit0: K,
) -> [K; 5] {
) -> [K; 7] {
let four = K::from(F::from_u64(4));
let two32 = K::from(F::from_u64(1u64 << 32));
let taken = control_branch_taken_from_bits(shout_val, funct3_bit0);
[
op_jal * (pc_after - pc_before - imm_j),
op_jalr * (pc_after - rs1_val - imm_i + jalr_drop_bit),
op_branch * (pc_after - pc_before - four - taken * (imm_b - four)),
op_jal * (pc_after + pc_carry * two32 - pc_before - imm_j),
op_jalr * (pc_after + pc_carry * two32 - rs1_val - imm_i + jalr_drop_bit),
op_branch * (pc_after + pc_carry * two32 - pc_before - four - taken * (imm_b - four)),
op_jalr * jalr_drop_bit * (jalr_drop_bit - K::ONE),
(active - op_jalr) * jalr_drop_bit,
pc_carry * (pc_carry - K::ONE),
(active - op_jal - op_jalr - op_branch) * pc_carry,
]
}

Expand Down Expand Up @@ -1328,6 +1340,7 @@ pub(crate) fn rv32_trace_wp_columns(layout: &Rv32TraceLayout) -> Vec<usize> {
layout.shout_lhs,
layout.shout_rhs,
layout.jalr_drop_bit,
layout.pc_carry,
]
}

Expand Down
4 changes: 2 additions & 2 deletions crates/neo-fold/src/memory_sidecar/route_a_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ pub fn prove_route_a_batched_time(
let proof = BatchedTimeProof {
claimed_sums: claimed_sums.clone(),
degree_bounds: degree_bounds.clone(),
labels: labels.clone(),
labels: labels.iter().map(|label| label.to_vec()).collect(),
round_polys: per_claim_results
.iter()
.map(|r| r.round_polys.clone())
Expand Down Expand Up @@ -449,7 +449,7 @@ pub fn verify_route_a_batched_time(
)));
}
for (i, (got, exp)) in proof.labels.iter().zip(expected_labels.iter()).enumerate() {
if (*got as &[u8]) != *exp {
if got.as_slice() != *exp {
return Err(PiCcsError::ProtocolError(format!(
"step {}: batched_time label mismatch at claim {}",
step_idx, i
Expand Down
Loading