Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions crates/neo-ajtai/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use p3_field::PrimeCharacteristicRing;
use p3_goldilocks::Goldilocks as Fq;
use serde::{Deserialize, Serialize};
use serde::de::Error as _;
use serde::{Deserialize, Deserializer, Serialize};

/// Public parameters for Ajtai: M ∈ R_q^{κ×m}, stored row-major.
#[derive(Clone, Debug, Serialize, Deserialize)]
Expand All @@ -13,7 +14,7 @@ pub struct PP<RqEl> {
}

/// Commitment c ∈ F_q^{d×κ}, stored as column-major flat matrix (κ columns, each length d).
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
pub struct Commitment {
pub d: usize,
pub kappa: usize,
Expand All @@ -22,6 +23,26 @@ pub struct Commitment {
}

impl Commitment {
#[inline]
fn validate_shape(d: usize, kappa: usize, data_len: usize) -> Result<(), String> {
let expected_d = neo_math::ring::D;
if d != expected_d {
return Err(format!("invalid Commitment.d: expected {expected_d}, got {d}"));
}

let expected_len = d
.checked_mul(kappa)
.ok_or_else(|| format!("invalid Commitment shape: d*kappa overflow (d={d}, kappa={kappa})"))?;
if data_len != expected_len {
return Err(format!(
"invalid Commitment shape: data.len()={} but d*kappa={expected_len}",
data_len
));
}

Ok(())
}

pub fn zeros(d: usize, kappa: usize) -> Self {
Self {
d,
Expand All @@ -48,3 +69,26 @@ impl Commitment {
}
}
}

impl<'de> Deserialize<'de> for Commitment {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct CommitmentWire {
d: usize,
kappa: usize,
data: Vec<Fq>,
}

let wire = CommitmentWire::deserialize(deserializer)?;
Commitment::validate_shape(wire.d, wire.kappa, wire.data.len()).map_err(D::Error::custom)?;

Ok(Self {
d: wire.d,
kappa: wire.kappa,
data: wire.data,
})
}
}
41 changes: 41 additions & 0 deletions crates/neo-ajtai/tests/commitment_deserialize_invariants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use neo_ajtai::Commitment;
use neo_math::ring::D;
use serde_json::Value;

fn valid_commitment_json(kappa: usize) -> Value {
let c = Commitment::zeros(D, kappa);
serde_json::to_value(&c).expect("serialize valid commitment")
}

#[test]
fn commitment_deserialize_rejects_wrong_d() {
let mut value = valid_commitment_json(2);
value["d"] = serde_json::json!(D - 1);

let err = serde_json::from_value::<Commitment>(value).expect_err("wrong d must be rejected");
let msg = err.to_string();
assert!(msg.contains("invalid Commitment.d"), "unexpected error message: {msg}");
}

#[test]
fn commitment_deserialize_rejects_data_len_mismatch() {
let mut value = valid_commitment_json(2);
let data = value
.get_mut("data")
.and_then(Value::as_array_mut)
.expect("commitment.data array");
data.pop().expect("non-empty data");

let err = serde_json::from_value::<Commitment>(value).expect_err("shape mismatch must be rejected");
let msg = err.to_string();
assert!(msg.contains("data.len()"), "unexpected error message: {msg}");
}

#[test]
fn commitment_deserialize_accepts_valid_shape() {
let value = valid_commitment_json(3);
let c: Commitment = serde_json::from_value(value).expect("valid shape should deserialize");
assert_eq!(c.d, D);
assert_eq!(c.kappa, 3);
assert_eq!(c.data.len(), D * 3);
}
4 changes: 2 additions & 2 deletions crates/neo-fold/src/memory_sidecar/route_a_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ pub fn prove_route_a_batched_time(
let proof = BatchedTimeProof {
claimed_sums: claimed_sums.clone(),
degree_bounds: degree_bounds.clone(),
labels: labels.clone(),
labels: labels.iter().map(|label| label.to_vec()).collect(),
round_polys: per_claim_results
.iter()
.map(|r| r.round_polys.clone())
Expand Down Expand Up @@ -223,7 +223,7 @@ pub fn verify_route_a_batched_time(
)));
}
for (i, (got, exp)) in proof.labels.iter().zip(expected_labels.iter()).enumerate() {
if (*got as &[u8]) != *exp {
if got.as_slice() != *exp {
return Err(PiCcsError::ProtocolError(format!(
"step {}: batched_time label mismatch at claim {}",
step_idx, i
Expand Down
Loading