From 4850c38afe2056a251b21f39dc0dc999f75e8d08 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 01:04:30 +0400 Subject: [PATCH 01/29] wip --- crates/xmss/src/lib.rs | 11 +-- crates/xmss/src/phony_xmss.rs | 78 --------------- crates/xmss/src/wots.rs | 75 +++++++++++--- crates/xmss/src/xmss.rs | 167 +++++++++++++++++--------------- crates/xmss/tests/xmss_tests.rs | 18 ++++ 5 files changed, 172 insertions(+), 177 deletions(-) delete mode 100644 crates/xmss/src/phony_xmss.rs create mode 100644 crates/xmss/tests/xmss_tests.rs diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 6237f163..d6c2ad85 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -10,11 +10,11 @@ use utils::poseidon16_compress; pub use wots::*; mod xmss; pub use xmss::*; -mod phony_xmss; -pub use phony_xmss::*; + +pub(crate) const DIGEST_SIZE: usize = 8; type F = KoalaBear; -type Digest = [F; 8]; +type Digest = [F; DIGEST_SIZE]; // WOTS pub const V: usize = 66; @@ -22,8 +22,7 @@ pub const W: usize = 4; pub const CHAIN_LENGTH: usize = 1 << W; pub const D: usize = 80; pub const TARGET_SUM: usize = V * (W - 1) - D; -pub const XMSS_MIN_LOG_LIFETIME: usize = 2; -pub const XMSS_MAX_LOG_LIFETIME: usize = 30; +pub const LOG_LIFETIME: usize = 32; pub type Poseidon16History = Vec<([F; 16], [F; 8])>; @@ -31,5 +30,5 @@ fn poseidon16_compress_with_trace(a: &Digest, b: &Digest, poseidon_16_trace: &mu let input: [F; 16] = [*a, *b].concat().try_into().unwrap(); let output = poseidon16_compress(input); poseidon_16_trace.push((input, output)); - output[0..8].try_into().unwrap() + output[0..DIGEST_SIZE].try_into().unwrap() } diff --git a/crates/xmss/src/phony_xmss.rs b/crates/xmss/src/phony_xmss.rs deleted file mode 100644 index 302008b5..00000000 --- a/crates/xmss/src/phony_xmss.rs +++ /dev/null @@ -1,78 +0,0 @@ -use crate::*; -use multilinear_toolkit::prelude::*; -use rand::{Rng, SeedableRng, rngs::StdRng}; - -// Only 1 WOTS, everything else in the merkle tree is random -// Useful for benchmark with a big lifetime, to speed up keys generation - -#[derive(Debug)] -struct PhonyXmssSecretKey { - wots_secret_key: WotsSecretKey, - first_slot: u64, - signature_slot: u64, - merkle_path: Vec, - public_key: XmssPublicKey, -} - -impl PhonyXmssSecretKey { - fn random(rng: &mut impl Rng, first_slot: u64, log_lifetime: usize, signature_slot: u64) -> Self { - assert!( - signature_slot.checked_sub(first_slot).unwrap() < (1 << log_lifetime), - "Index out of bounds for XMSS signature" - ); - let wots_secret_key = WotsSecretKey::random(rng); - let mut merkle_path = Vec::new(); - let mut hash = wots_secret_key.public_key().hash(); - let wots_index = signature_slot - first_slot; - for i in 0..log_lifetime { - let phony_neighbour: Digest = rng.random(); - let is_left = (wots_index >> i).is_multiple_of(2); - if is_left { - hash = poseidon16_compress([hash, phony_neighbour].concat().try_into().unwrap()); - } else { - hash = poseidon16_compress([phony_neighbour, hash].concat().try_into().unwrap()); - }; - merkle_path.push(phony_neighbour); - } - Self { - wots_secret_key, - first_slot, - signature_slot, - merkle_path, - public_key: XmssPublicKey { - merkle_root: hash, - log_lifetime, - first_slot, - }, - } - } - - fn sign(&self, message_hash: &Digest, rng: &mut impl Rng) -> XmssSignature { - let wots_signature = self.wots_secret_key.sign(message_hash, rng); - XmssSignature { - wots_signature, - merkle_proof: self.merkle_path.clone(), - slot: self.signature_slot, - } - } -} - -pub fn xmss_generate_phony_signatures( - log_lifetimes: &[usize], - message_hash: Digest, - slot: u64, -) -> (Vec, Vec) { - log_lifetimes - .par_iter() - .enumerate() - .map(|(i, &log_lifetime)| { - assert!(log_lifetime >= XMSS_MIN_LOG_LIFETIME); - assert!(log_lifetime <= XMSS_MAX_LOG_LIFETIME); - let mut rng = StdRng::seed_from_u64(i as u64); - let first_slot = slot - rng.random_range(0..(1 << log_lifetime).min(slot)); - let xmss_secret_key = PhonyXmssSecretKey::random(&mut rng, first_slot, log_lifetime, slot); - let signature = xmss_secret_key.sign(&message_hash, &mut rng); - (xmss_secret_key.public_key, signature) - }) - .unzip() -} diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index fcb38c52..181cf90a 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -36,8 +36,14 @@ impl WotsSecretKey { &self.public_key } - pub fn sign(&self, message_hash: &Digest, rng: &mut impl Rng) -> WotsSignature { - let (randomness, encoding) = find_randomness_for_wots_encoding(message_hash, rng); + pub fn sign( + &self, + message_hash: &Digest, + epoch: u32, + truncated_merkle_root: &[F; 6], + rng: &mut impl Rng, + ) -> WotsSignature { + let (randomness, encoding) = find_randomness_for_wots_encoding(message_hash, epoch, truncated_merkle_root, rng); WotsSignature { chain_tips: std::array::from_fn(|i| iterate_hash(&self.pre_images[i], encoding[i] as usize)), randomness, @@ -46,17 +52,37 @@ impl WotsSecretKey { } impl WotsSignature { - pub fn recover_public_key(&self, message_hash: &Digest, signature: &Self) -> Option { - self.recover_public_key_with_poseidon_trace(message_hash, signature, &mut Vec::new()) + pub fn recover_public_key( + &self, + message_hash: &Digest, + epoch: u32, + truncated_merkle_root: &[F; 6], + signature: &Self, + ) -> Option { + self.recover_public_key_with_poseidon_trace( + message_hash, + epoch, + truncated_merkle_root, + signature, + &mut Vec::new(), + ) } pub fn recover_public_key_with_poseidon_trace( &self, message_hash: &Digest, + epoch: u32, + truncated_merkle_root: &[F; 6], signature: &Self, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option { - let encoding = wots_encode_with_poseidon_trace(message_hash, &signature.randomness, poseidon_16_trace)?; + let encoding = wots_encode_with_poseidon_trace( + message_hash, + epoch, + truncated_merkle_root, + &signature.randomness, + poseidon_16_trace, + )?; Some(WotsPublicKey(std::array::from_fn(|i| { iterate_hash_with_poseidon_trace(&self.chain_tips[i], W - 1 - encoding[i] as usize, poseidon_16_trace) }))) @@ -76,9 +102,7 @@ impl WotsPublicKey { } pub fn iterate_hash(a: &Digest, n: usize) -> Digest { - (0..n).fold(*a, |acc, _| { - poseidon16_compress([acc, Default::default()].concat().try_into().unwrap()) - }) + (0..n).fold(*a, |acc, _| poseidon16_compress([acc, Default::default()].concat().try_into().unwrap())) } pub fn iterate_hash_with_poseidon_trace( @@ -91,25 +115,50 @@ pub fn iterate_hash_with_poseidon_trace( }) } -pub fn find_randomness_for_wots_encoding(message: &Digest, rng: &mut impl Rng) -> (Digest, [u8; V]) { +pub fn find_randomness_for_wots_encoding( + message: &Digest, + epoch: u32, + truncated_merkle_root: &[F; 6], + rng: &mut impl Rng, +) -> (Digest, [u8; V]) { loop { let randomness = rng.random(); - if let Some(encoding) = wots_encode(message, &randomness) { + if let Some(encoding) = wots_encode(message, epoch, truncated_merkle_root, &randomness) { return (randomness, encoding); } } } -pub fn wots_encode(message: &Digest, randomness: &Digest) -> Option<[u8; V]> { - wots_encode_with_poseidon_trace(message, randomness, &mut Vec::new()) +pub fn wots_encode( + message: &Digest, + epoch: u32, + truncated_merkle_root: &[F; 6], + randomness: &Digest, +) -> Option<[u8; V]> { + wots_encode_with_poseidon_trace(message, epoch, truncated_merkle_root, randomness, &mut Vec::new()) } pub fn wots_encode_with_poseidon_trace( message: &Digest, + epoch: u32, + truncated_merkle_root: &[F; 6], randomness: &Digest, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option<[u8; V]> { - let compressed = poseidon16_compress_with_trace(message, randomness, poseidon_16_trace); + // Encode epoch as 2 field elements (16 bits each) + let epoch_lo = F::from_usize((epoch & 0xFFFF) as usize); + let epoch_hi = F::from_usize(((epoch >> 16) & 0xFFFF) as usize); + + // A = poseidon(message (8 fe), epoch (2 fe), truncated_merkle_root (6 fe)) + let mut epoch_and_root = [F::default(); 8]; + epoch_and_root[0] = epoch_lo; + epoch_and_root[1] = epoch_hi; + epoch_and_root[2..8].copy_from_slice(truncated_merkle_root); + let a = poseidon16_compress_with_trace(message, &epoch_and_root, poseidon_16_trace); + + // B = poseidon(A (8 fe), randomness (8 fe)) + let compressed = poseidon16_compress_with_trace(&a, randomness, poseidon_16_trace); + if compressed.iter().any(|&kb| kb == -F::ONE) { return None; } diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index f6b4dad2..6c0d4f52 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -1,84 +1,106 @@ use multilinear_toolkit::prelude::*; -use rand::{SeedableRng, rngs::StdRng}; +use rand::{Rng, SeedableRng, rngs::StdRng}; use sha3::{Digest as Sha3Digest, Keccak256}; use crate::*; #[derive(Debug)] pub struct XmssSecretKey { - pub(crate) first_slot: u64, + pub(crate) start: u32, + pub(crate) end: u32, pub(crate) seed: [u8; 32], + // merkle_tree[level] stores nodes at that level whose subtree overlaps [start, end]. + // At level l, stored indices go from (start >> l) to (end >> l). pub(crate) merkle_tree: Vec>, } #[derive(Debug)] pub struct XmssSignature { pub wots_signature: WotsSignature, - pub slot: u64, // unused for now (Toy XMSS) + pub slot: u32, pub merkle_proof: Vec, } #[derive(Debug)] pub struct XmssPublicKey { pub merkle_root: Digest, - pub first_slot: u64, - pub log_lifetime: usize, } fn gen_wots_secret_key(seed: &[u8; 32], slot: u64) -> WotsSecretKey { let mut hasher = Keccak256::new(); + hasher.update(b"wots_secret_key"); hasher.update(seed); hasher.update(slot.to_le_bytes()); let mut rng = StdRng::from_seed(hasher.finalize().into()); WotsSecretKey::random(&mut rng) } +/// Deterministic pseudo-random digest for an out-of-range tree node. +fn gen_random_node(seed: &[u8; 32], level: usize, index: u64) -> Digest { + let mut hasher = Keccak256::new(); + hasher.update(b"random_node"); + hasher.update(seed); + hasher.update((level as u64).to_le_bytes()); + hasher.update(index.to_le_bytes()); + let mut rng = StdRng::from_seed(hasher.finalize().into()); + rng.random() +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum XmssKeyGenError { - LogLifetimeTooSmall, - LogLifetimeTooLarge, - FirstSlotTooLarge, + InvalidRange, } -pub fn xmss_key_gen( - seed: [u8; 32], - first_slot: u64, - log_lifetime: usize, -) -> Result<(XmssSecretKey, XmssPublicKey), XmssKeyGenError> { - if first_slot >= (1 << XMSS_MAX_LOG_LIFETIME) { - return Err(XmssKeyGenError::FirstSlotTooLarge); - } - if log_lifetime < XMSS_MIN_LOG_LIFETIME { - return Err(XmssKeyGenError::LogLifetimeTooSmall); +pub fn xmss_key_gen(seed: [u8; 32], start: u32, end: u32) -> Result<(XmssSecretKey, XmssPublicKey), XmssKeyGenError> { + if start > end || end as u64 >= (1 << LOG_LIFETIME) { + return Err(XmssKeyGenError::InvalidRange); } - if log_lifetime > XMSS_MAX_LOG_LIFETIME { - return Err(XmssKeyGenError::LogLifetimeTooLarge); - } - let leaves = (first_slot..first_slot + (1 << log_lifetime)) + // Level 0: WOTS leaf hashes for slots in [start, end] + let leaves: Vec = (start as u64..end as u64 + 1) .into_par_iter() .map(|slot| { let wots = gen_wots_secret_key(&seed, slot); wots.public_key().hash() }) - .collect::>(); + .collect(); let mut merkle_tree = vec![leaves]; - for _ in 0..log_lifetime { - merkle_tree.push( - merkle_tree - .last() - .unwrap() - .par_chunks(2) - .map(|chunk| poseidon16_compress([chunk[0], chunk[1]].concat().try_into().unwrap())) - .collect(), - ); + // Build levels 1..=LOG_LIFETIME. + // At level l, we store nodes with index in [(start >> l), (end >> l)]. + // Children outside [start, end]'s subtree are replaced by gen_random_node. + for level in 1..=LOG_LIFETIME { + let base: u64 = (start as u64) >> level; + let top: u64 = (end as u64) >> level; + let prev_base: u64 = (start as u64) >> (level - 1); + let prev_top: u64 = (end as u64) >> (level - 1); + let nodes: Vec = { + let prev = &merkle_tree[level - 1]; + (base..top + 1) + .into_par_iter() + .map(|i| { + let left_idx = 2 * i; + let right_idx = 2 * i + 1; + let left = if left_idx >= prev_base && left_idx <= prev_top { + prev[(left_idx - prev_base) as usize] + } else { + gen_random_node(&seed, level - 1, left_idx) + }; + let right = if right_idx >= prev_base && right_idx <= prev_top { + prev[(right_idx - prev_base) as usize] + } else { + gen_random_node(&seed, level - 1, right_idx) + }; + poseidon16_compress([left, right].concat().try_into().unwrap()) + }) + .collect() + }; + merkle_tree.push(nodes); } let pub_key = XmssPublicKey { - first_slot, - merkle_root: *merkle_tree.last().unwrap().first().unwrap(), - log_lifetime, + merkle_root: merkle_tree.last().unwrap()[0], }; let secret_key = XmssSecretKey { - first_slot, + start, + end, seed, merkle_tree, }; @@ -87,34 +109,33 @@ pub fn xmss_key_gen( #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum XmssSignatureError { - SlotTooEarly, - SlotTooLate, + SlotOutOfRange, } pub fn xmss_sign( randomness_seed: [u8; 32], secret_key: &XmssSecretKey, message_hash: &[F; 8], - slot: u64, + slot: u32, ) -> Result { - if slot < secret_key.first_slot { - return Err(XmssSignatureError::SlotTooEarly); + if slot < secret_key.start || slot > secret_key.end { + return Err(XmssSignatureError::SlotOutOfRange); } - let wots_index = slot - secret_key.first_slot; - if wots_index >= secret_key.lifetime() { - return Err(XmssSignatureError::SlotTooLate); - } - let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot); + let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot as u64); + let merkle_root = secret_key.public_key().merkle_root; + let truncated_merkle_root: [F; 6] = merkle_root[0..6].try_into().unwrap(); let mut rng = StdRng::from_seed(randomness_seed); - let wots_signature = wots_secret_key.sign(message_hash, &mut rng); - let merkle_proof = (0..secret_key.log_lifetime()) - .scan(wots_index, |current_idx, level| { - let neighbour_index = *current_idx ^ 1; - // TODO edge case if usize = u32 and MAX_LOG_LIFETIME > 32 ? (unlikely) - let neighbour = secret_key.merkle_tree[level][neighbour_index as usize]; - // Move up to the next level. - *current_idx /= 2; - Some(neighbour) + let wots_signature = wots_secret_key.sign(message_hash, slot, &truncated_merkle_root, &mut rng); + let merkle_proof = (0..LOG_LIFETIME) + .map(|level| { + let neighbour_index = ((slot as u64) >> level) ^ 1; + let base = (secret_key.start as u64) >> level; + let top = (secret_key.end as u64) >> level; + if neighbour_index >= base && neighbour_index <= top { + secret_key.merkle_tree[level][(neighbour_index - base) as usize] + } else { + gen_random_node(&secret_key.seed, level, neighbour_index) + } }) .collect(); Ok(XmssSignature { @@ -125,27 +146,15 @@ pub fn xmss_sign( } impl XmssSecretKey { - pub fn log_lifetime(&self) -> usize { - self.merkle_tree.len() - 1 - } - - pub fn lifetime(&self) -> u64 { - 1 << self.log_lifetime() - } - pub fn public_key(&self) -> XmssPublicKey { XmssPublicKey { - first_slot: self.first_slot, merkle_root: self.merkle_tree.last().unwrap()[0], - log_lifetime: self.log_lifetime(), } } } #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum XmssVerifyError { - SlotTooEarly, - SlotTooLate, InvalidWots, InvalidMerklePath, } @@ -154,9 +163,7 @@ pub fn xmss_verify( pub_key: &XmssPublicKey, message_hash: &Digest, signature: &XmssSignature, - slot: u64, ) -> Result<(), XmssVerifyError> { - let _ = slot; // TODO xmss_verify_with_poseidon_trace(pub_key, message_hash, signature).map(|_| ()) } @@ -165,27 +172,27 @@ pub fn xmss_verify_with_poseidon_trace( message_hash: &Digest, signature: &XmssSignature, ) -> Result { - if signature.slot < pub_key.first_slot { - return Err(XmssVerifyError::SlotTooEarly); - } - let wots_index = signature.slot - pub_key.first_slot; - if wots_index >= (1 << pub_key.log_lifetime) { - return Err(XmssVerifyError::SlotTooLate); - } let mut poseidon_16_trace = Vec::new(); + let truncated_merkle_root: [F; 6] = pub_key.merkle_root[0..6].try_into().unwrap(); let wots_public_key = signature .wots_signature - .recover_public_key_with_poseidon_trace(message_hash, &signature.wots_signature, &mut poseidon_16_trace) + .recover_public_key_with_poseidon_trace( + message_hash, + signature.slot, + &truncated_merkle_root, + &signature.wots_signature, + &mut poseidon_16_trace, + ) .ok_or(XmssVerifyError::InvalidWots)?; - // merkle root verification let mut current_hash = wots_public_key.hash_with_poseidon_trace(&mut poseidon_16_trace); - if signature.merkle_proof.len() != pub_key.log_lifetime { + if signature.merkle_proof.len() != LOG_LIFETIME { return Err(XmssVerifyError::InvalidMerklePath); } + let wots_index = signature.slot as u64; for (level, neighbour) in signature.merkle_proof.iter().enumerate() { let is_left = ((wots_index >> level) & 1) == 0; if is_left { - current_hash = poseidon16_compress_with_trace(¤t_hash, neighbour, &mut poseidon_16_trace) + current_hash = poseidon16_compress_with_trace(¤t_hash, neighbour, &mut poseidon_16_trace); } else { current_hash = poseidon16_compress_with_trace(neighbour, ¤t_hash, &mut poseidon_16_trace); } diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs new file mode 100644 index 00000000..e6e410c2 --- /dev/null +++ b/crates/xmss/tests/xmss_tests.rs @@ -0,0 +1,18 @@ +use multilinear_toolkit::prelude::PrimeCharacteristicRing; +use p3_koala_bear::KoalaBear; +use xmss::*; + +type F = KoalaBear; + +#[test] +fn keygen_sign_verify() { + let keygen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); + let randomness_seed: [u8; 32] = std::array::from_fn(|i| (i * 2 + 1) as u8); + let message: [F; 8] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); + + let (sk, pk) = xmss_key_gen(keygen_seed, 100, 115).unwrap(); + for slot in 100..=115 { + let sig = xmss_sign(randomness_seed, &sk, &message, slot).unwrap(); + xmss_verify(&pk, &message, &sig).unwrap(); + } +} From b7776673fcf6d3b1303002879fb461fab1d0236b Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 03:12:52 +0400 Subject: [PATCH 02/29] V_GRINDING --- crates/xmss/src/lib.rs | 5 +++-- crates/xmss/src/wots.rs | 38 ++++++++++++++++++++++++--------- crates/xmss/tests/xmss_tests.rs | 23 +++++++++++++++++++- 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index d6c2ad85..9e2be5bc 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -20,8 +20,9 @@ type Digest = [F; DIGEST_SIZE]; pub const V: usize = 66; pub const W: usize = 4; pub const CHAIN_LENGTH: usize = 1 << W; -pub const D: usize = 80; -pub const TARGET_SUM: usize = V * (W - 1) - D; +pub const NUM_CHAIN_HASHES: usize = 80; +pub const TARGET_SUM: usize = V * (W - 1) - NUM_CHAIN_HASHES; +pub const V_GRINDING: usize = 1; pub const LOG_LIFETIME: usize = 32; pub type Poseidon16History = Vec<([F; 16], [F; 8])>; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 181cf90a..7be56ae0 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -43,7 +43,8 @@ impl WotsSecretKey { truncated_merkle_root: &[F; 6], rng: &mut impl Rng, ) -> WotsSignature { - let (randomness, encoding) = find_randomness_for_wots_encoding(message_hash, epoch, truncated_merkle_root, rng); + let (randomness, encoding, _) = + find_randomness_for_wots_encoding(message_hash, epoch, truncated_merkle_root, rng); WotsSignature { chain_tips: std::array::from_fn(|i| iterate_hash(&self.pre_images[i], encoding[i] as usize)), randomness, @@ -102,7 +103,9 @@ impl WotsPublicKey { } pub fn iterate_hash(a: &Digest, n: usize) -> Digest { - (0..n).fold(*a, |acc, _| poseidon16_compress([acc, Default::default()].concat().try_into().unwrap())) + (0..n).fold(*a, |acc, _| { + poseidon16_compress([acc, Default::default()].concat().try_into().unwrap()) + }) } pub fn iterate_hash_with_poseidon_trace( @@ -120,11 +123,13 @@ pub fn find_randomness_for_wots_encoding( epoch: u32, truncated_merkle_root: &[F; 6], rng: &mut impl Rng, -) -> (Digest, [u8; V]) { +) -> (Digest, [u8; V], usize) { + let mut num_iters = 0; loop { + num_iters += 1; let randomness = rng.random(); if let Some(encoding) = wots_encode(message, epoch, truncated_merkle_root, &randomness) { - return (randomness, encoding); + return (randomness, encoding, num_iters); } } } @@ -162,12 +167,12 @@ pub fn wots_encode_with_poseidon_trace( if compressed.iter().any(|&kb| kb == -F::ONE) { return None; } - let encoding: Vec<_> = compressed + let all_indices: Vec<_> = compressed .iter() .flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24)) .collect::>() .chunks_exact(log2_strict_usize(W)) - .take(V) + .take(V + V_GRINDING) .map(|chunk| { chunk .iter() @@ -175,11 +180,24 @@ pub fn wots_encode_with_poseidon_trace( .fold(0u8, |acc, (i, &bit)| acc | (u8::from(bit) << i)) }) .collect(); - is_valid_encoding(&encoding).then(|| encoding.try_into().unwrap()) + is_valid_encoding(&all_indices).then(|| all_indices[..V].try_into().unwrap()) } fn is_valid_encoding(encoding: &[u8]) -> bool { - encoding.len() == V - && encoding.iter().all(|&x| (x as usize) < W) - && encoding.iter().map(|&x| x as usize).sum::() == TARGET_SUM + if encoding.len() != V + V_GRINDING { + return false; + } + // All indices must be < W + if !encoding.iter().all(|&x| (x as usize) < W) { + return false; + } + // First V indices must sum to TARGET_SUM + if encoding[..V].iter().map(|&x| x as usize).sum::() != TARGET_SUM { + return false; + } + // Last V_GRINDING indices must all be W-1 (grinding constraint) + if !encoding[V..].iter().all(|&x| x as usize == W - 1) { + return false; + } + true } diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index e6e410c2..283870bb 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -1,5 +1,6 @@ -use multilinear_toolkit::prelude::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use p3_koala_bear::KoalaBear; +use rand::{SeedableRng, rngs::StdRng}; use xmss::*; type F = KoalaBear; @@ -16,3 +17,23 @@ fn keygen_sign_verify() { xmss_verify(&pk, &message, &sig).unwrap(); } } + +#[test] +#[ignore] +fn encoding_grinding_bits() { + let n = 1000; + let total_iters = (0..10_000) + .into_par_iter() + .map(|i| { + let message: [F; 8] = Default::default(); + let epoch = i as u32; + let truncated_merkle_root: [F; 6] = Default::default(); + let mut rng = StdRng::seed_from_u64(i as u64); + let (_randomness, _encoding, num_iters) = + find_randomness_for_wots_encoding(&message, epoch, &truncated_merkle_root, &mut rng); + num_iters + }) + .sum::(); + let grinding = ((total_iters as f64) / (n as f64)).log2(); + println!("Average grinding bits: {:.1}", grinding); +} From 2ee393438a475bcf357c519637d33c8244f0300b Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 11:32:19 +0400 Subject: [PATCH 03/29] wip --- crates/rec_aggregation/xmss_aggregate.py | 15 ++++++--------- crates/xmss/tests/xmss_tests.rs | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 86c720c8..56111ab9 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -4,9 +4,6 @@ W = 4 TARGET_SUM = 118 MAX_LOG_LIFETIME = 30 - -V_HALF = V / 2 # V should be even - VECTOR_LEN = 8 # Dot product precompile: @@ -115,7 +112,7 @@ def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): dot_product(var_1, ONE_VEC_PTR, var_2, 1, EE) dot_product(var_3, ONE_VEC_PTR, var_4, 1, EE) - wots_pubkey_hashed = slice_hash(ZERO_VEC_PTR, public_key, V_HALF) + wots_pubkey_hashed = slice_hash(public_key, V) debug_assert(log_lifetime < MAX_LOG_LIFETIME + 1) @@ -223,15 +220,15 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): return state_indexes[height - 1] -def slice_hash(seed, data, half_len: Const): - states = Array(half_len * 2 * VECTOR_LEN) +def slice_hash(data, len: Const): + states = Array(len * VECTOR_LEN) poseidon16(ZERO_VEC_PTR, data, states) - state_indexes = Array(half_len * 2) + state_indexes = Array(len) state_indexes[0] = states - for j in unroll(1, (half_len * 2)): + for j in unroll(1, (len)): state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j]) - return state_indexes[half_len * 2 - 1] + return state_indexes[len - 1] @inline diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index 283870bb..626b82e8 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -21,8 +21,8 @@ fn keygen_sign_verify() { #[test] #[ignore] fn encoding_grinding_bits() { - let n = 1000; - let total_iters = (0..10_000) + let n = 100; + let total_iters = (0..n) .into_par_iter() .map(|i| { let message: [F; 8] = Default::default(); From 982ac0155fed87808889987b85321d50c21b7d15 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 11:56:32 +0400 Subject: [PATCH 04/29] wip --- Cargo.lock | 1 - crates/lean_prover/src/test_zkvm.rs | 4 ++-- crates/rec_aggregation/hashing.py | 20 ++++++++++---------- crates/rec_aggregation/recursion.py | 2 +- crates/rec_aggregation/utils.py | 2 +- crates/rec_aggregation/whir.py | 4 ++-- crates/xmss/Cargo.toml | 1 - crates/xmss/src/lib.rs | 4 ++-- crates/xmss/src/wots.rs | 19 +++++++++++-------- 9 files changed, 29 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5b87285..d4025b25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1354,7 +1354,6 @@ version = "0.1.0" dependencies = [ "multilinear-toolkit", "p3-koala-bear", - "p3-util", "rand", "sha3", "utils", diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 45de7190..0078dd27 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -20,7 +20,7 @@ fn test_zk_vm_all_precompiles_helper(fuzzing: bool) { let program_str = r#" DIM = 5 N = 11 -VECTOR_LEN = 8 +DIGEST_LEN = 8 # Dot product precompile: BE = 1 # base-extension @@ -28,7 +28,7 @@ EE = 0 # extension-extension def main(): pub_start = NONRESERVED_PROGRAM_INPUT_START - poseidon16(pub_start + 4 * VECTOR_LEN, pub_start + 5 * VECTOR_LEN, pub_start + 6 * VECTOR_LEN) + poseidon16(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN) dot_product(pub_start + 88, pub_start + 88 + N, pub_start + 1000, N, BE) dot_product(pub_start + 88 + N, pub_start + 88 + N * (DIM + 1), pub_start + 1000 + DIM, N, EE) c: Mut = 0 diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index 4c6c0f07..fe368f1c 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -1,7 +1,7 @@ from snark_lib import * DIM = 5 # extension degree -VECTOR_LEN = 8 +DIGEST_LEN = 8 WHIR_MERKLE_HEIGHTS = WHIR_MERKLE_HEIGHTS_PLACEHOLDER WHIR_NUM_QUERIES = WHIR_NUM_QUERIES_PLACEHOLDER @@ -44,13 +44,13 @@ def batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, @inline def slice_hash(data, len): - states = Array((len - 1) * VECTOR_LEN) - poseidon16(data, data + VECTOR_LEN, states) + states = Array((len - 1) * DIGEST_LEN) + poseidon16(data, data + DIGEST_LEN, states) state_indexes = Array(len) state_indexes[0] = states for j in unroll(1, len - 1): - state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN - poseidon16(state_indexes[j - 1], data + (j + 1) * VECTOR_LEN, state_indexes[j]) + state_indexes[j] = state_indexes[j - 1] + DIGEST_LEN + poseidon16(state_indexes[j - 1], data + (j + 1) * DIGEST_LEN, state_indexes[j]) return state_indexes[len - 2] @@ -81,7 +81,7 @@ def merkle_verif_batch_const(n_paths: Const, merkle_paths, leaves_digests, leave for i in unroll(0, n_paths): merkle_verify( leaves_digests[i], - merkle_paths + (i * height) * VECTOR_LEN, + merkle_paths + (i * height) * DIGEST_LEN, leave_positions[i], root, height, @@ -91,7 +91,7 @@ def merkle_verif_batch_const(n_paths: Const, merkle_paths, leaves_digests, leave def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Const): - states = Array(height * VECTOR_LEN) + states = Array(height * DIGEST_LEN) # First merkle round match leaf_position_bits[0]: @@ -104,18 +104,18 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Co state_indexes = Array(height) state_indexes[0] = states for j in unroll(1, height): - state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN + state_indexes[j] = state_indexes[j - 1] + DIGEST_LEN # Warning: this works only if leaf_position_bits[i] is known to be boolean: match leaf_position_bits[j]: case 0: poseidon16( state_indexes[j - 1], - merkle_path + j * VECTOR_LEN, + merkle_path + j * DIGEST_LEN, state_indexes[j], ) case 1: poseidon16( - merkle_path + j * VECTOR_LEN, + merkle_path + j * DIGEST_LEN, state_indexes[j - 1], state_indexes[j], ) diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py index 9c8a51d7..6b5a8f83 100644 --- a/crates/rec_aggregation/recursion.py +++ b/crates/rec_aggregation/recursion.py @@ -61,7 +61,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip fs: Mut = fs_new(proof_transcript) # table dims - debug_assert(N_TABLES + 1 < VECTOR_LEN) + debug_assert(N_TABLES + 1 < DIGEST_LEN) fs, mem_and_table_dims = fs_receive_chunks(fs, 1) for i in unroll(N_TABLES + 1, 8): assert mem_and_table_dims[i] == 0 diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index f3af2d41..7b72ed11 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -448,7 +448,7 @@ def print_ef(a): def print_vec(a): - for i in unroll(0, VECTOR_LEN): + for i in unroll(0, DIGEST_LEN): print(a[i]) return diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index 1d03cdfb..06318911 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -201,9 +201,9 @@ def sample_stir_indexes_and_fold( answers[i] = answer leaf_hashes = Array(num_queries) # a vector of vectorized pointers, each pointing to 1 chunk of 8 field elements - batch_hash_slice(num_queries, answers, leaf_hashes, n_chunks_per_answer / VECTOR_LEN) + batch_hash_slice(num_queries, answers, leaf_hashes, n_chunks_per_answer / DIGEST_LEN) - fs, merkle_paths = fs_hint(fs, folded_domain_size * num_queries * VECTOR_LEN) + fs, merkle_paths = fs_hint(fs, folded_domain_size * num_queries * DIGEST_LEN) # Merkle verification merkle_verif_batch( diff --git a/crates/xmss/Cargo.toml b/crates/xmss/Cargo.toml index b6cff886..b1705363 100644 --- a/crates/xmss/Cargo.toml +++ b/crates/xmss/Cargo.toml @@ -10,6 +10,5 @@ workspace = true p3-koala-bear.workspace = true rand.workspace = true utils.workspace = true -p3-util.workspace = true multilinear-toolkit.workspace = true sha3.workspace = true \ No newline at end of file diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 9e2be5bc..b90abd67 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -18,10 +18,10 @@ type Digest = [F; DIGEST_SIZE]; // WOTS pub const V: usize = 66; -pub const W: usize = 4; +pub const W: usize = 2; pub const CHAIN_LENGTH: usize = 1 << W; pub const NUM_CHAIN_HASHES: usize = 80; -pub const TARGET_SUM: usize = V * (W - 1) - NUM_CHAIN_HASHES; +pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; pub const V_GRINDING: usize = 1; pub const LOG_LIFETIME: usize = 32; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 7be56ae0..e50add45 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -1,5 +1,4 @@ use multilinear_toolkit::prelude::*; -use p3_util::log2_strict_usize; use rand::{Rng, RngCore}; use utils::{ToUsize, to_little_endian_bits}; @@ -28,7 +27,7 @@ impl WotsSecretKey { pub fn new(pre_images: [Digest; V]) -> Self { Self { pre_images, - public_key: WotsPublicKey(std::array::from_fn(|i| iterate_hash(&pre_images[i], W - 1))), + public_key: WotsPublicKey(std::array::from_fn(|i| iterate_hash(&pre_images[i], CHAIN_LENGTH - 1))), } } @@ -85,7 +84,11 @@ impl WotsSignature { poseidon_16_trace, )?; Some(WotsPublicKey(std::array::from_fn(|i| { - iterate_hash_with_poseidon_trace(&self.chain_tips[i], W - 1 - encoding[i] as usize, poseidon_16_trace) + iterate_hash_with_poseidon_trace( + &self.chain_tips[i], + CHAIN_LENGTH - 1 - encoding[i] as usize, + poseidon_16_trace, + ) }))) } } @@ -171,7 +174,7 @@ pub fn wots_encode_with_poseidon_trace( .iter() .flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24)) .collect::>() - .chunks_exact(log2_strict_usize(W)) + .chunks_exact(W) .take(V + V_GRINDING) .map(|chunk| { chunk @@ -187,16 +190,16 @@ fn is_valid_encoding(encoding: &[u8]) -> bool { if encoding.len() != V + V_GRINDING { return false; } - // All indices must be < W - if !encoding.iter().all(|&x| (x as usize) < W) { + // All indices must be < CHAIN_LENGTH + if !encoding.iter().all(|&x| (x as usize) < CHAIN_LENGTH) { return false; } // First V indices must sum to TARGET_SUM if encoding[..V].iter().map(|&x| x as usize).sum::() != TARGET_SUM { return false; } - // Last V_GRINDING indices must all be W-1 (grinding constraint) - if !encoding[V..].iter().all(|&x| x as usize == W - 1) { + // Last V_GRINDING indices must all be CHAIN_LENGTH-1 (grinding constraint) + if !encoding[V..].iter().all(|&x| x as usize == CHAIN_LENGTH - 1) { return false; } true From b9f128fb56f2a7c2be4116035aaabff66cf8db3b Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 12:04:33 +0400 Subject: [PATCH 05/29] 9 fe per message --- crates/xmss/src/wots.rs | 44 ++++++++++++++++----------------- crates/xmss/src/xmss.rs | 14 +++++------ crates/xmss/tests/xmss_tests.rs | 6 ++--- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index e50add45..e8608dfe 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -37,13 +37,12 @@ impl WotsSecretKey { pub fn sign( &self, - message_hash: &Digest, + message: &[F; 9], epoch: u32, - truncated_merkle_root: &[F; 6], + truncated_merkle_root: &[F; 5], rng: &mut impl Rng, ) -> WotsSignature { - let (randomness, encoding, _) = - find_randomness_for_wots_encoding(message_hash, epoch, truncated_merkle_root, rng); + let (randomness, encoding, _) = find_randomness_for_wots_encoding(message, epoch, truncated_merkle_root, rng); WotsSignature { chain_tips: std::array::from_fn(|i| iterate_hash(&self.pre_images[i], encoding[i] as usize)), randomness, @@ -54,13 +53,13 @@ impl WotsSecretKey { impl WotsSignature { pub fn recover_public_key( &self, - message_hash: &Digest, + message: &[F; 9], epoch: u32, - truncated_merkle_root: &[F; 6], + truncated_merkle_root: &[F; 5], signature: &Self, ) -> Option { self.recover_public_key_with_poseidon_trace( - message_hash, + message, epoch, truncated_merkle_root, signature, @@ -70,14 +69,14 @@ impl WotsSignature { pub fn recover_public_key_with_poseidon_trace( &self, - message_hash: &Digest, + message: &[F; 9], epoch: u32, - truncated_merkle_root: &[F; 6], + truncated_merkle_root: &[F; 5], signature: &Self, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option { let encoding = wots_encode_with_poseidon_trace( - message_hash, + message, epoch, truncated_merkle_root, &signature.randomness, @@ -122,9 +121,9 @@ pub fn iterate_hash_with_poseidon_trace( } pub fn find_randomness_for_wots_encoding( - message: &Digest, + message: &[F; 9], epoch: u32, - truncated_merkle_root: &[F; 6], + truncated_merkle_root: &[F; 5], rng: &mut impl Rng, ) -> (Digest, [u8; V], usize) { let mut num_iters = 0; @@ -138,18 +137,18 @@ pub fn find_randomness_for_wots_encoding( } pub fn wots_encode( - message: &Digest, + message: &[F; 9], epoch: u32, - truncated_merkle_root: &[F; 6], + truncated_merkle_root: &[F; 5], randomness: &Digest, ) -> Option<[u8; V]> { wots_encode_with_poseidon_trace(message, epoch, truncated_merkle_root, randomness, &mut Vec::new()) } pub fn wots_encode_with_poseidon_trace( - message: &Digest, + message: &[F; 9], epoch: u32, - truncated_merkle_root: &[F; 6], + truncated_merkle_root: &[F; 5], randomness: &Digest, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option<[u8; V]> { @@ -157,12 +156,13 @@ pub fn wots_encode_with_poseidon_trace( let epoch_lo = F::from_usize((epoch & 0xFFFF) as usize); let epoch_hi = F::from_usize(((epoch >> 16) & 0xFFFF) as usize); - // A = poseidon(message (8 fe), epoch (2 fe), truncated_merkle_root (6 fe)) - let mut epoch_and_root = [F::default(); 8]; - epoch_and_root[0] = epoch_lo; - epoch_and_root[1] = epoch_hi; - epoch_and_root[2..8].copy_from_slice(truncated_merkle_root); - let a = poseidon16_compress_with_trace(message, &epoch_and_root, poseidon_16_trace); + // A = poseidon(message (9 fe), epoch (2 fe), truncated_merkle_root (5 fe)) + let mut a_input_right = [F::default(); 8]; + a_input_right[0] = message[8]; + a_input_right[1] = epoch_lo; + a_input_right[2] = epoch_hi; + a_input_right[3..8].copy_from_slice(truncated_merkle_root); + let a = poseidon16_compress_with_trace(message[..8].try_into().unwrap(), &a_input_right, poseidon_16_trace); // B = poseidon(A (8 fe), randomness (8 fe)) let compressed = poseidon16_compress_with_trace(&a, randomness, poseidon_16_trace); diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 6c0d4f52..5d743127 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -115,7 +115,7 @@ pub enum XmssSignatureError { pub fn xmss_sign( randomness_seed: [u8; 32], secret_key: &XmssSecretKey, - message_hash: &[F; 8], + message_hash: &[F; 9], slot: u32, ) -> Result { if slot < secret_key.start || slot > secret_key.end { @@ -123,7 +123,7 @@ pub fn xmss_sign( } let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot as u64); let merkle_root = secret_key.public_key().merkle_root; - let truncated_merkle_root: [F; 6] = merkle_root[0..6].try_into().unwrap(); + let truncated_merkle_root: [F; 5] = merkle_root[0..5].try_into().unwrap(); let mut rng = StdRng::from_seed(randomness_seed); let wots_signature = wots_secret_key.sign(message_hash, slot, &truncated_merkle_root, &mut rng); let merkle_proof = (0..LOG_LIFETIME) @@ -161,23 +161,23 @@ pub enum XmssVerifyError { pub fn xmss_verify( pub_key: &XmssPublicKey, - message_hash: &Digest, + message: &[F; 9], signature: &XmssSignature, ) -> Result<(), XmssVerifyError> { - xmss_verify_with_poseidon_trace(pub_key, message_hash, signature).map(|_| ()) + xmss_verify_with_poseidon_trace(pub_key, message, signature).map(|_| ()) } pub fn xmss_verify_with_poseidon_trace( pub_key: &XmssPublicKey, - message_hash: &Digest, + message: &[F; 9], signature: &XmssSignature, ) -> Result { let mut poseidon_16_trace = Vec::new(); - let truncated_merkle_root: [F; 6] = pub_key.merkle_root[0..6].try_into().unwrap(); + let truncated_merkle_root: [F; 5] = pub_key.merkle_root[0..5].try_into().unwrap(); let wots_public_key = signature .wots_signature .recover_public_key_with_poseidon_trace( - message_hash, + message, signature.slot, &truncated_merkle_root, &signature.wots_signature, diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index 626b82e8..ff041899 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -9,7 +9,7 @@ type F = KoalaBear; fn keygen_sign_verify() { let keygen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); let randomness_seed: [u8; 32] = std::array::from_fn(|i| (i * 2 + 1) as u8); - let message: [F; 8] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); + let message: [F; 9] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); let (sk, pk) = xmss_key_gen(keygen_seed, 100, 115).unwrap(); for slot in 100..=115 { @@ -25,9 +25,9 @@ fn encoding_grinding_bits() { let total_iters = (0..n) .into_par_iter() .map(|i| { - let message: [F; 8] = Default::default(); + let message: [F; 9] = Default::default(); let epoch = i as u32; - let truncated_merkle_root: [F; 6] = Default::default(); + let truncated_merkle_root: [F; 5] = Default::default(); let mut rng = StdRng::seed_from_u64(i as u64); let (_randomness, _encoding, num_iters) = find_randomness_for_wots_encoding(&message, epoch, &truncated_merkle_root, &mut rng); From c8a6550cfe8cd88ed5e0a7f3c403c38c6774e8e1 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 13:24:35 +0400 Subject: [PATCH 06/29] wip --- crates/rec_aggregation/src/xmss_aggregate.rs | 106 ++++++++++--------- crates/rec_aggregation/xmss_aggregate.py | 66 ++++++------ crates/xmss/src/lib.rs | 4 + crates/xmss/src/wots.rs | 66 ++++++------ crates/xmss/src/xmss.rs | 16 ++- crates/xmss/tests/xmss_tests.rs | 11 +- 6 files changed, 138 insertions(+), 131 deletions(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index c61b2a08..6e4abb9b 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -3,14 +3,13 @@ use lean_prover::{default_whir_config, prove_execution::prove_execution, verify_ use lean_vm::*; use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; -use std::path::Path; use std::sync::OnceLock; use std::time::Instant; +use std::{collections::BTreeMap, path::Path}; use tracing::{info_span, instrument}; -use utils::to_little_endian_in_field; use xmss::{ - Poseidon16History, V, XMSS_MAX_LOG_LIFETIME, XmssPublicKey, XmssSignature, xmss_generate_phony_signatures, - xmss_verify_with_poseidon_trace, + LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, SIG_SIZE_FE, TARGET_SUM, V, W, XmssPublicKey, XmssSignature, + slot_to_field_elements, xmss_verify_with_poseidon_trace, }; static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); @@ -24,25 +23,21 @@ pub fn xmss_setup_aggregation_program() { let _ = get_xmss_aggregation_program(); } -fn build_public_input(xmss_pub_keys: &[XmssPublicKey], message_hash: [F; 8], slot: u64) -> Vec { - let mut public_input = vec![F::from_usize(xmss_pub_keys.len())]; - public_input.extend(message_hash.to_vec()); - public_input.extend(xmss_pub_keys.iter().flat_map(|pk| pk.merkle_root)); - public_input.extend(xmss_pub_keys.iter().map(|pk| F::from_usize(pk.log_lifetime))); - for pk in xmss_pub_keys { - let index_in_merkle_tree = slot.checked_sub(pk.first_slot).unwrap() as usize; - public_input.extend(to_little_endian_in_field::( - index_in_merkle_tree, - XMSS_MAX_LOG_LIFETIME, - )); - } - let mut acc = F::ZERO; - for pk in xmss_pub_keys { - public_input.push(acc); - acc += F::from_usize((1 + V + pk.log_lifetime) * DIGEST_LEN); // size of the signature +fn build_public_input(xmss_pub_keys: &[XmssPublicKey], message: [F; MESSAGE_LEN_FE], slot: u32) -> Vec { + let mut public_input = vec![]; + public_input.push(F::ZERO); // private input start, filled later + public_input.push(F::from_usize(xmss_pub_keys.len())); + public_input.extend(message.to_vec()); + let [slot_lo, slot_hi] = slot_to_field_elements(slot); + public_input.push(slot_lo); + public_input.push(slot_hi); + for level in 0..LOG_LIFETIME { + let is_left = (((slot as u64) >> level) & 1) == 0; + public_input.push(F::from_usize(is_left as usize)); } - let private_input_start = (NONRESERVED_PROGRAM_INPUT_START + 1 + public_input.len()).next_power_of_two(); - public_input.insert(0, F::from_usize(private_input_start)); + public_input.extend(xmss_pub_keys.iter().flat_map(|pk| pk.merkle_root)); + let private_input_start = (NONRESERVED_PROGRAM_INPUT_START + public_input.len()).next_power_of_two(); + public_input[0] = F::from_usize(private_input_start); public_input } @@ -63,7 +58,7 @@ fn build_private_input(all_signatures: &[XmssSignature]) -> Vec { } let sig_size = private_input.len() - initial_private_input_len; - assert!(sig_size.is_multiple_of(DIGEST_LEN)); + assert_eq!(sig_size, SIG_SIZE_FE); } private_input } @@ -75,33 +70,48 @@ fn compile_xmss_aggregation_program() -> Bytecode { .to_str() .unwrap() .to_string(); - compile_program(&ProgramSource::Filepath(filepath)) + let mut replacements = BTreeMap::new(); + replacements.insert("V_PLACEHOLDER".to_string(), V.to_string()); + replacements.insert("W_PLACEHOLDER".to_string(), W.to_string()); + replacements.insert("TARGET_SUM_PLACEHOLDER".to_string(), TARGET_SUM.to_string()); + replacements.insert("LOG_LIFETIME_PLACEHOLDER".to_string(), LOG_LIFETIME.to_string()); + compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements }) } -pub fn run_xmss_benchmark(log_lifetimes: &[usize], tracing: bool) { +pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { if tracing { utils::init_tracing(); } xmss_setup_aggregation_program(); precompute_dft_twiddles::(1 << 24); - let mut rng = StdRng::seed_from_u64(0); - let message_hash: [F; 8] = rng.random(); + let message: [F; MESSAGE_LEN_FE] = (0..9).map(|i| F::from_usize(i)).collect::>().try_into().unwrap(); let slot = 1111; - let (xmss_pub_keys, all_signatures) = xmss_generate_phony_signatures(log_lifetimes, message_hash, slot); + let pub_keys_and_sigs = (0..n_signatures) + .into_par_iter() + .map(|i| { + let mut rng = StdRng::seed_from_u64(i as u64); + let start = slot - rng.random_range(0..5); + let end = slot + rng.random_range(1..5); + let (sk, pk) = xmss::xmss_key_gen(rng.random(), start, end).unwrap(); + let sig = xmss::xmss_sign(&mut rng, &sk, &message, slot).unwrap(); + (pk, sig) + }) + .collect::>(); + let (xmss_pub_keys, all_signatures): (Vec<_>, Vec<_>) = pub_keys_and_sigs.into_iter().unzip(); let time = Instant::now(); let (proof_data, n_field_elements_in_proof, summary) = - xmss_aggregate_signatures_helper(&xmss_pub_keys, &all_signatures, message_hash, slot).unwrap(); + xmss_aggregate_signatures_helper(&xmss_pub_keys, &all_signatures, message, slot).unwrap(); let proving_time = time.elapsed(); - xmss_verify_aggregated_signatures(&xmss_pub_keys, message_hash, &proof_data, slot).unwrap(); + xmss_verify_aggregated_signatures(&xmss_pub_keys, message, &proof_data, slot).unwrap(); println!("{summary}"); println!( "XMSS aggregation, proving time: {:.3} s ({:.1} XMSS/s), proof size: {} KiB", proving_time.as_secs_f64(), - log_lifetimes.len() as f64 / proving_time.as_secs_f64(), + n_signatures as f64 / proving_time.as_secs_f64(), n_field_elements_in_proof * F::bits() / (8 * 1024) ); } @@ -115,17 +125,17 @@ pub enum XmssAggregateError { pub fn xmss_aggregate_signatures( xmss_pub_keys: &[XmssPublicKey], all_signatures: &[XmssSignature], - message_hash: [F; 8], - slot: u64, + message: [F; MESSAGE_LEN_FE], + slot: u32, ) -> Result, XmssAggregateError> { - Ok(xmss_aggregate_signatures_helper(xmss_pub_keys, all_signatures, message_hash, slot)?.0) + Ok(xmss_aggregate_signatures_helper(xmss_pub_keys, all_signatures, message, slot)?.0) } fn xmss_aggregate_signatures_helper( xmss_pub_keys: &[XmssPublicKey], all_signatures: &[XmssSignature], - message_hash: [F; 8], - slot: u64, + message: [F; MESSAGE_LEN_FE], + slot: u32, ) -> Result<(Vec, usize, String), XmssAggregateError> { if xmss_pub_keys.len() != all_signatures.len() { return Err(XmssAggregateError::WrongSignatureCount); @@ -133,10 +143,10 @@ fn xmss_aggregate_signatures_helper( let program = get_xmss_aggregation_program(); - let poseidons_16_precomputed = precompute_poseidons(xmss_pub_keys, all_signatures, &message_hash) - .ok_or(XmssAggregateError::InvalidSigature)?; + let poseidons_16_precomputed = + precompute_poseidons(xmss_pub_keys, all_signatures, &message).ok_or(XmssAggregateError::InvalidSigature)?; - let public_input = build_public_input(xmss_pub_keys, message_hash, slot); + let public_input = build_public_input(xmss_pub_keys, message, slot); let private_input = build_private_input(all_signatures); let proof = prove_execution( @@ -154,18 +164,17 @@ fn xmss_aggregate_signatures_helper( pub fn xmss_verify_aggregated_signatures( xmss_pub_keys: &[XmssPublicKey], - message_hash: [F; 8], + message: [F; MESSAGE_LEN_FE], proof_bytes: &[u8], - slot: u64, + slot: u32, ) -> Result<(), ProofError> { - let _ = slot; // TODO let program = get_xmss_aggregation_program(); let proof = info_span!("Proof deserialization") .in_scope(|| bincode::deserialize(proof_bytes)) .map_err(|_| ProofError::InvalidProof)?; - let public_input = build_public_input(xmss_pub_keys, message_hash, slot); + let public_input = build_public_input(xmss_pub_keys, message, slot); verify_execution(program, &public_input, proof, &default_whir_config(LOG_INV_RATE)).map(|_| ()) } @@ -174,13 +183,13 @@ pub fn xmss_verify_aggregated_signatures( fn precompute_poseidons( xmss_pub_keys: &[XmssPublicKey], all_signatures: &[XmssSignature], - message_hash: &[F; 8], + message: &[F; MESSAGE_LEN_FE], ) -> Option { assert_eq!(xmss_pub_keys.len(), all_signatures.len()); let traces = xmss_pub_keys .par_iter() .zip(all_signatures.par_iter()) - .map(|(pub_key, sig)| xmss_verify_with_poseidon_trace(pub_key, message_hash, sig)) + .map(|(pub_key, sig)| xmss_verify_with_poseidon_trace(pub_key, message, sig)) .collect::, _>>() .ok()?; Some(traces.into_par_iter().flatten().collect()) @@ -188,10 +197,5 @@ fn precompute_poseidons( #[test] fn test_xmss_aggregate() { - let n_xmss = 10; - let mut rng = StdRng::seed_from_u64(0); - let log_lifetimes = (0..n_xmss) - .map(|_| rng.random_range(xmss::XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME)) - .collect::>(); - run_xmss_benchmark(&log_lifetimes, false); + run_xmss_benchmark(10, false); } diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 56111ab9..cd671059 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -1,13 +1,15 @@ from snark_lib import * -V = 66 -W = 4 -TARGET_SUM = 118 -MAX_LOG_LIFETIME = 30 -VECTOR_LEN = 8 +V = V_PLACEHOLDER +W = W_PLACEHOLDER +TARGET_SUM = TARGET_SUM_PLACEHOLDER +LOG_LIFETIME = LOG_LIFETIME_PLACEHOLDER +DIGEST_LEN = 8 +RANDOMNESS_LEN = 8 +SIG_SIZE = RANDOMNESS_LEN + (V + LOG_LIFETIME) * DIGEST_LEN # Dot product precompile: -BE = 1 # base-extension +BE = 1 # base-extension (unused for XMSS) EE = 0 # extension-extension @@ -16,13 +18,13 @@ def main(): signatures_start = pub_mem[0] n_signatures = pub_mem[1] message_hash = pub_mem + 2 - all_public_keys = message_hash + VECTOR_LEN - all_log_lifetimes = all_public_keys + n_signatures * VECTOR_LEN + all_public_keys = message_hash + DIGEST_LEN + all_log_lifetimes = all_public_keys + n_signatures * DIGEST_LEN all_merkle_indexes = all_log_lifetimes + n_signatures sig_sizes = all_merkle_indexes + n_signatures * MAX_LOG_LIFETIME for i in range(0, n_signatures): - xmss_public_key = all_public_keys + i * VECTOR_LEN + xmss_public_key = all_public_keys + i * DIGEST_LEN signature = signatures_start + sig_sizes[i] log_lifetime = all_log_lifetimes[i] merkle_index = all_merkle_indexes + i * MAX_LOG_LIFETIME @@ -35,12 +37,12 @@ def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): # signature: randomness | chain_tips # return the hashed xmss public key randomness = signature - chain_tips = signature + VECTOR_LEN - merkle_path = chain_tips + V * VECTOR_LEN + chain_tips = signature + DIGEST_LEN + merkle_path = chain_tips + V * DIGEST_LEN # 1) We encode message_hash + randomness into the d-th layer of the hypercube - compressed = Array(VECTOR_LEN) + compressed = Array(DIGEST_LEN) poseidon16(message_hash, randomness, compressed) compressed_vals = Array(6) dot_product(compressed, ONE_VEC_PTR, compressed_vals, 1, EE) @@ -78,35 +80,35 @@ def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): target_sum += encoding[i] assert target_sum == TARGET_SUM - public_key = Array(V * VECTOR_LEN) + public_key = Array(V * DIGEST_LEN) # This is a trick to avoid the compiler to allocate memory "on stack". # (Heap allocation is better here, to keep the memmory use of the different "match arms" balanced) - vector_len = VECTOR_LEN + digest_len = DIGEST_LEN for i in unroll(0, V): match encoding[i]: case 0: - var_1 = chain_tips + i * VECTOR_LEN - var_2 = public_key + i * VECTOR_LEN - var_3 = Array(vector_len) - var_4 = Array(vector_len) + var_1 = chain_tips + i * DIGEST_LEN + var_2 = public_key + i * DIGEST_LEN + var_3 = Array(digest_len) + var_4 = Array(digest_len) poseidon16(var_1, ZERO_VEC_PTR, var_3) poseidon16(var_3, ZERO_VEC_PTR, var_4) poseidon16(var_4, ZERO_VEC_PTR, var_2) case 1: - var_3 = Array(vector_len) - var_1 = chain_tips + i * VECTOR_LEN - var_2 = public_key + i * VECTOR_LEN + var_3 = Array(digest_len) + var_1 = chain_tips + i * DIGEST_LEN + var_2 = public_key + i * DIGEST_LEN poseidon16(var_1, ZERO_VEC_PTR, var_3) poseidon16(var_3, ZERO_VEC_PTR, var_2) case 2: - var_1 = chain_tips + i * VECTOR_LEN - var_2 = public_key + i * VECTOR_LEN + var_1 = chain_tips + i * DIGEST_LEN + var_2 = public_key + i * DIGEST_LEN poseidon16(var_1, ZERO_VEC_PTR, var_2) case 3: - var_1 = chain_tips + (i * VECTOR_LEN) - var_2 = public_key + (i * VECTOR_LEN) + var_1 = chain_tips + (i * digest_len) + var_2 = public_key + (i * digest_len) var_3 = var_1 + 3 var_4 = var_2 + 3 dot_product(var_1, ONE_VEC_PTR, var_2, 1, EE) @@ -189,7 +191,7 @@ def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): - states = Array(height * VECTOR_LEN) + states = Array(height * DIGEST_LEN) # First merkle round match leaf_position_bits[0]: @@ -202,18 +204,18 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): state_indexes = Array(height) state_indexes[0] = states for j in unroll(1, height): - state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN + state_indexes[j] = state_indexes[j - 1] + DIGEST_LEN # Warning: this works only if leaf_position_bits[i] is known to be boolean: match leaf_position_bits[j]: case 0: poseidon16( state_indexes[j - 1], - merkle_path + j * VECTOR_LEN, + merkle_path + j * DIGEST_LEN, state_indexes[j], ) case 1: poseidon16( - merkle_path + j * VECTOR_LEN, + merkle_path + j * DIGEST_LEN, state_indexes[j - 1], state_indexes[j], ) @@ -221,13 +223,13 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): def slice_hash(data, len: Const): - states = Array(len * VECTOR_LEN) + states = Array(len * DIGEST_LEN) poseidon16(ZERO_VEC_PTR, data, states) state_indexes = Array(len) state_indexes[0] = states for j in unroll(1, (len)): - state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN - poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j]) + state_indexes[j] = state_indexes[j - 1] + DIGEST_LEN + poseidon16(state_indexes[j - 1], data + j * DIGEST_LEN, state_indexes[j]) return state_indexes[len - 1] diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index b90abd67..f0706d89 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -24,6 +24,10 @@ pub const NUM_CHAIN_HASHES: usize = 80; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; pub const V_GRINDING: usize = 1; pub const LOG_LIFETIME: usize = 32; +pub const RANDOMNESS_LEN_FE: usize = 8; +pub const MESSAGE_LEN_FE: usize = 9; + +pub const SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + (V + LOG_LIFETIME) * DIGEST_SIZE; pub type Poseidon16History = Vec<([F; 16], [F; 8])>; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index e8608dfe..0b5314f4 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -16,7 +16,7 @@ pub struct WotsPublicKey(pub [Digest; V]); #[derive(Debug)] pub struct WotsSignature { pub chain_tips: [Digest; V], - pub randomness: Digest, + pub randomness: [F; RANDOMNESS_LEN_FE], } impl WotsSecretKey { @@ -37,12 +37,12 @@ impl WotsSecretKey { pub fn sign( &self, - message: &[F; 9], - epoch: u32, + message: &[F; MESSAGE_LEN_FE], + slot: u32, truncated_merkle_root: &[F; 5], rng: &mut impl Rng, ) -> WotsSignature { - let (randomness, encoding, _) = find_randomness_for_wots_encoding(message, epoch, truncated_merkle_root, rng); + let (randomness, encoding, _) = find_randomness_for_wots_encoding(message, slot, truncated_merkle_root, rng); WotsSignature { chain_tips: std::array::from_fn(|i| iterate_hash(&self.pre_images[i], encoding[i] as usize)), randomness, @@ -53,31 +53,25 @@ impl WotsSecretKey { impl WotsSignature { pub fn recover_public_key( &self, - message: &[F; 9], - epoch: u32, + message: &[F; MESSAGE_LEN_FE], + slot: u32, truncated_merkle_root: &[F; 5], signature: &Self, ) -> Option { - self.recover_public_key_with_poseidon_trace( - message, - epoch, - truncated_merkle_root, - signature, - &mut Vec::new(), - ) + self.recover_public_key_with_poseidon_trace(message, slot, truncated_merkle_root, signature, &mut Vec::new()) } pub fn recover_public_key_with_poseidon_trace( &self, - message: &[F; 9], - epoch: u32, + message: &[F; MESSAGE_LEN_FE], + slot: u32, truncated_merkle_root: &[F; 5], signature: &Self, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option { let encoding = wots_encode_with_poseidon_trace( message, - epoch, + slot, truncated_merkle_root, &signature.randomness, poseidon_16_trace, @@ -121,46 +115,45 @@ pub fn iterate_hash_with_poseidon_trace( } pub fn find_randomness_for_wots_encoding( - message: &[F; 9], - epoch: u32, + message: &[F; MESSAGE_LEN_FE], + slot: u32, truncated_merkle_root: &[F; 5], rng: &mut impl Rng, -) -> (Digest, [u8; V], usize) { +) -> ([F; RANDOMNESS_LEN_FE], [u8; V], usize) { let mut num_iters = 0; loop { num_iters += 1; let randomness = rng.random(); - if let Some(encoding) = wots_encode(message, epoch, truncated_merkle_root, &randomness) { + if let Some(encoding) = wots_encode(message, slot, truncated_merkle_root, &randomness) { return (randomness, encoding, num_iters); } } } pub fn wots_encode( - message: &[F; 9], - epoch: u32, + message: &[F; MESSAGE_LEN_FE], + slot: u32, truncated_merkle_root: &[F; 5], - randomness: &Digest, + randomness: &[F; RANDOMNESS_LEN_FE], ) -> Option<[u8; V]> { - wots_encode_with_poseidon_trace(message, epoch, truncated_merkle_root, randomness, &mut Vec::new()) + wots_encode_with_poseidon_trace(message, slot, truncated_merkle_root, randomness, &mut Vec::new()) } pub fn wots_encode_with_poseidon_trace( - message: &[F; 9], - epoch: u32, + message: &[F; MESSAGE_LEN_FE], + slot: u32, truncated_merkle_root: &[F; 5], - randomness: &Digest, + randomness: &[F; RANDOMNESS_LEN_FE], poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option<[u8; V]> { - // Encode epoch as 2 field elements (16 bits each) - let epoch_lo = F::from_usize((epoch & 0xFFFF) as usize); - let epoch_hi = F::from_usize(((epoch >> 16) & 0xFFFF) as usize); + // Encode slot as 2 field elements (16 bits each) + let [slot_lo, slot_hi] = slot_to_field_elements(slot); - // A = poseidon(message (9 fe), epoch (2 fe), truncated_merkle_root (5 fe)) + // A = poseidon(message (9 fe), slot (2 fe), truncated_merkle_root (5 fe)) let mut a_input_right = [F::default(); 8]; a_input_right[0] = message[8]; - a_input_right[1] = epoch_lo; - a_input_right[2] = epoch_hi; + a_input_right[1] = slot_lo; + a_input_right[2] = slot_hi; a_input_right[3..8].copy_from_slice(truncated_merkle_root); let a = poseidon16_compress_with_trace(message[..8].try_into().unwrap(), &a_input_right, poseidon_16_trace); @@ -204,3 +197,10 @@ fn is_valid_encoding(encoding: &[u8]) -> bool { } true } + +pub fn slot_to_field_elements(slot: u32) -> [F; 2] { + [ + F::from_usize((slot & 0xFFFF) as usize), + F::from_usize(((slot >> 16) & 0xFFFF) as usize), + ] +} diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 5d743127..4087e3de 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -112,10 +112,10 @@ pub enum XmssSignatureError { SlotOutOfRange, } -pub fn xmss_sign( - randomness_seed: [u8; 32], +pub fn xmss_sign( + rng: &mut R, secret_key: &XmssSecretKey, - message_hash: &[F; 9], + message_hash: &[F; MESSAGE_LEN_FE], slot: u32, ) -> Result { if slot < secret_key.start || slot > secret_key.end { @@ -124,8 +124,7 @@ pub fn xmss_sign( let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot as u64); let merkle_root = secret_key.public_key().merkle_root; let truncated_merkle_root: [F; 5] = merkle_root[0..5].try_into().unwrap(); - let mut rng = StdRng::from_seed(randomness_seed); - let wots_signature = wots_secret_key.sign(message_hash, slot, &truncated_merkle_root, &mut rng); + let wots_signature = wots_secret_key.sign(message_hash, slot, &truncated_merkle_root, rng); let merkle_proof = (0..LOG_LIFETIME) .map(|level| { let neighbour_index = ((slot as u64) >> level) ^ 1; @@ -161,7 +160,7 @@ pub enum XmssVerifyError { pub fn xmss_verify( pub_key: &XmssPublicKey, - message: &[F; 9], + message: &[F; MESSAGE_LEN_FE], signature: &XmssSignature, ) -> Result<(), XmssVerifyError> { xmss_verify_with_poseidon_trace(pub_key, message, signature).map(|_| ()) @@ -169,7 +168,7 @@ pub fn xmss_verify( pub fn xmss_verify_with_poseidon_trace( pub_key: &XmssPublicKey, - message: &[F; 9], + message: &[F; MESSAGE_LEN_FE], signature: &XmssSignature, ) -> Result { let mut poseidon_16_trace = Vec::new(); @@ -188,9 +187,8 @@ pub fn xmss_verify_with_poseidon_trace( if signature.merkle_proof.len() != LOG_LIFETIME { return Err(XmssVerifyError::InvalidMerklePath); } - let wots_index = signature.slot as u64; for (level, neighbour) in signature.merkle_proof.iter().enumerate() { - let is_left = ((wots_index >> level) & 1) == 0; + let is_left = (((signature.slot as u64) >> level) & 1) == 0; if is_left { current_hash = poseidon16_compress_with_trace(¤t_hash, neighbour, &mut poseidon_16_trace); } else { diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index ff041899..d5cb7ee7 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -8,12 +8,11 @@ type F = KoalaBear; #[test] fn keygen_sign_verify() { let keygen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); - let randomness_seed: [u8; 32] = std::array::from_fn(|i| (i * 2 + 1) as u8); - let message: [F; 9] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); + let message: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); let (sk, pk) = xmss_key_gen(keygen_seed, 100, 115).unwrap(); for slot in 100..=115 { - let sig = xmss_sign(randomness_seed, &sk, &message, slot).unwrap(); + let sig = xmss_sign(&mut StdRng::seed_from_u64(slot as u64), &sk, &message, slot).unwrap(); xmss_verify(&pk, &message, &sig).unwrap(); } } @@ -25,12 +24,12 @@ fn encoding_grinding_bits() { let total_iters = (0..n) .into_par_iter() .map(|i| { - let message: [F; 9] = Default::default(); - let epoch = i as u32; + let message: [F; MESSAGE_LEN_FE] = Default::default(); + let slot = i as u32; let truncated_merkle_root: [F; 5] = Default::default(); let mut rng = StdRng::seed_from_u64(i as u64); let (_randomness, _encoding, num_iters) = - find_randomness_for_wots_encoding(&message, epoch, &truncated_merkle_root, &mut rng); + find_randomness_for_wots_encoding(&message, slot, &truncated_merkle_root, &mut rng); num_iters }) .sum::(); From e9980424316e3371e50fdbb3f6bda1056c166e12 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 14:25:23 +0400 Subject: [PATCH 07/29] wip --- crates/lean_compiler/snark_lib.py | 3 + crates/lean_vm/src/isa/hint.rs | 23 +- crates/rec_aggregation/src/xmss_aggregate.rs | 11 +- crates/rec_aggregation/xmss_aggregate.py | 242 +++++++------------ crates/xmss/src/lib.rs | 3 +- crates/xmss/src/wots.rs | 27 ++- crates/xmss/src/xmss.rs | 4 +- crates/xmss/tests/xmss_tests.rs | 2 +- 8 files changed, 133 insertions(+), 182 deletions(-) diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index b09063ac..5dbb791c 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -120,3 +120,6 @@ def match_range(value: int, *args): if value in rng: return fn(value) raise AssertionError(f"Value {value} not in any range") + +def hint_decompose_bits_xmss(*args): + _ = args \ No newline at end of file diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 603445ef..ea473cbb 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -90,7 +90,7 @@ impl CustomHint { pub fn n_args_range(&self) -> Range { match self { - Self::DecomposeBitsXMSS => 3..usize::MAX, + Self::DecomposeBitsXMSS => 5..6, Self::DecomposeBits => 4..5, Self::LessThan => 3..4, } @@ -99,15 +99,18 @@ impl CustomHint { pub fn execute(&self, args: &[MemOrConstant], ctx: &mut HintExecutionContext<'_>) -> Result<(), RunnerError> { match self { Self::DecomposeBitsXMSS => { - let decomposed = &args[0]; - let remaining = &args[1]; - let to_decompose = &args[2..]; - let mut memory_index_decomposed = decomposed.read_value(ctx.memory, ctx.fp)?.to_usize(); - let mut memory_index_remaining = remaining.read_value(ctx.memory, ctx.fp)?.to_usize(); - for value_source in to_decompose { - let value = value_source.read_value(ctx.memory, ctx.fp)?.to_usize(); - for i in 0..12 { - let value = F::from_usize((value >> (2 * i)) & 0b11); + let decomposed_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); + let remaining_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); + let to_decompose_ptr = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); + let num_to_decompose = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); + let w = args[4].read_value(ctx.memory, ctx.fp)?.to_usize(); + assert!(w == 2 || w == 3 || w == 4); + let mut memory_index_decomposed = decomposed_ptr; + let mut memory_index_remaining = remaining_ptr; + for i in 0..num_to_decompose { + let value = ctx.memory.get(to_decompose_ptr + i)?.to_usize(); + for i in 0..24 / w { + let value = F::from_usize((value >> (w * i)) & 0b11); ctx.memory.set(memory_index_decomposed, value)?; memory_index_decomposed += 1; } diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 6e4abb9b..1111ba3a 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -8,8 +8,7 @@ use std::time::Instant; use std::{collections::BTreeMap, path::Path}; use tracing::{info_span, instrument}; use xmss::{ - LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, SIG_SIZE_FE, TARGET_SUM, V, W, XmssPublicKey, XmssSignature, - slot_to_field_elements, xmss_verify_with_poseidon_trace, + LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, SIG_SIZE_FE, TARGET_SUM, V, V_GRINDING, W, XmssPublicKey, XmssSignature, slot_to_field_elements, xmss_verify_with_poseidon_trace }; static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); @@ -72,9 +71,11 @@ fn compile_xmss_aggregation_program() -> Bytecode { .to_string(); let mut replacements = BTreeMap::new(); replacements.insert("V_PLACEHOLDER".to_string(), V.to_string()); + replacements.insert("V_GRINDING_PLACEHOLDER".to_string(), V_GRINDING.to_string()); replacements.insert("W_PLACEHOLDER".to_string(), W.to_string()); replacements.insert("TARGET_SUM_PLACEHOLDER".to_string(), TARGET_SUM.to_string()); replacements.insert("LOG_LIFETIME_PLACEHOLDER".to_string(), LOG_LIFETIME.to_string()); + replacements.insert("MESSAGE_LEN_PLACEHOLDER".to_string(), MESSAGE_LEN_FE.to_string()); compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements }) } @@ -85,7 +86,11 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { xmss_setup_aggregation_program(); precompute_dft_twiddles::(1 << 24); - let message: [F; MESSAGE_LEN_FE] = (0..9).map(|i| F::from_usize(i)).collect::>().try_into().unwrap(); + let message = (0..MESSAGE_LEN_FE) + .map(|i| F::from_usize(i)) + .collect::>() + .try_into() + .unwrap(); let slot = 1111; let pub_keys_and_sigs = (0..n_signatures) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index cd671059..344e05a1 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -1,14 +1,19 @@ from snark_lib import * V = V_PLACEHOLDER +V_GRINDING = V_GRINDING_PLACEHOLDER W = W_PLACEHOLDER +CHAIN_LENGTH = 2 ** W TARGET_SUM = TARGET_SUM_PLACEHOLDER LOG_LIFETIME = LOG_LIFETIME_PLACEHOLDER +MESSAGE_LEN = MESSAGE_LEN_PLACEHOLDER DIGEST_LEN = 8 RANDOMNESS_LEN = 8 SIG_SIZE = RANDOMNESS_LEN + (V + LOG_LIFETIME) * DIGEST_LEN +NUM_ENCODING_FE = div_ceil((V + V_GRINDING), (24 / W)) # 24 should be divisible by W (works for W=2,3,4) # Dot product precompile: +DIM = 5 BE = 1 # base-extension (unused for XMSS) EE = 0 # extension-extension @@ -17,62 +22,62 @@ def main(): pub_mem = NONRESERVED_PROGRAM_INPUT_START signatures_start = pub_mem[0] n_signatures = pub_mem[1] - message_hash = pub_mem + 2 - all_public_keys = message_hash + DIGEST_LEN - all_log_lifetimes = all_public_keys + n_signatures * DIGEST_LEN - all_merkle_indexes = all_log_lifetimes + n_signatures - sig_sizes = all_merkle_indexes + n_signatures * MAX_LOG_LIFETIME + message = pub_mem + 2 + slot_ptr = message + MESSAGE_LEN + slot_lo = slot_ptr[0] + slot_hi = slot_ptr[1] + merkle_indexes = slot_ptr + 2 # is left ? + all_public_keys = merkle_indexes + LOG_LIFETIME for i in range(0, n_signatures): - xmss_public_key = all_public_keys + i * DIGEST_LEN - signature = signatures_start + sig_sizes[i] - log_lifetime = all_log_lifetimes[i] - merkle_index = all_merkle_indexes + i * MAX_LOG_LIFETIME - xmss_public_key_recovered = xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index) - assert_eq_vec(xmss_public_key, xmss_public_key_recovered) + merkle_root = all_public_keys + i * DIGEST_LEN + signature = signatures_start + SIG_SIZE * i + xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_indexes) return -def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): - # signature: randomness | chain_tips +def xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_indexes): + # signature: randomness | chain_tips | merkle_path # return the hashed xmss public key randomness = signature - chain_tips = signature + DIGEST_LEN + chain_tips = signature + RANDOMNESS_LEN merkle_path = chain_tips + V * DIGEST_LEN - # 1) We encode message_hash + randomness into the d-th layer of the hypercube + # 1) We encode message_hash + randomness into the layer of the hypercube with target sum = TARGET_SUM - compressed = Array(DIGEST_LEN) - poseidon16(message_hash, randomness, compressed) - compressed_vals = Array(6) - dot_product(compressed, ONE_VEC_PTR, compressed_vals, 1, EE) - compressed_vals[5] = compressed[5] + a_input_right = Array(DIGEST_LEN) + b_input = Array(DIGEST_LEN * 2) + a_input_right[0] = message[8] + copy_7(randomness, a_input_right + 1) + poseidon16(message, a_input_right, b_input) + b_input[8] = slot_lo + b_input[9] = slot_hi + copy_6(merkle_root, b_input + 10) + encoding_fe = Array(DIGEST_LEN) + poseidon16(b_input, b_input + DIGEST_LEN, encoding_fe) - encoding = Array(12 * 6) - remaining = Array(6) + encoding = Array(NUM_ENCODING_FE * 24 / W) + remaining = Array(NUM_ENCODING_FE) hint_decompose_bits_xmss( encoding, remaining, - compressed_vals[0], - compressed_vals[1], - compressed_vals[2], - compressed_vals[3], - compressed_vals[4], - compressed_vals[5], + encoding_fe, + NUM_ENCODING_FE, + W ) # check that the decomposition is correct - for i in unroll(0, 6): - for j in unroll(0, 12): - assert encoding[i * 12 + j] <= 3 + for i in unroll(0, NUM_ENCODING_FE): + for j in unroll(0, 24 / W): + assert encoding[i * (24 / W) + j] < CHAIN_LENGTH - assert remaining[i] <= 2**7 - 2 + assert remaining[i] < 2**7 - 1 partial_sum: Mut = remaining[i] * 2**24 - for j in unroll(1, 13): - partial_sum += encoding[i * 12 + (j - 1)] * 4 ** (j - 1) - assert partial_sum == compressed_vals[i] + for j in unroll(0, 24/W): + partial_sum += encoding[i * (24 / W) + j] * CHAIN_LENGTH ** j + assert partial_sum == encoding_fe[i] # we need to check the target sum target_sum: Mut = encoding[0] @@ -80,117 +85,36 @@ def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): target_sum += encoding[i] assert target_sum == TARGET_SUM - public_key = Array(V * DIGEST_LEN) + # grinding + for i in unroll(V, V + V_GRINDING): + assert encoding[i] == CHAIN_LENGTH - 1 - # This is a trick to avoid the compiler to allocate memory "on stack". - # (Heap allocation is better here, to keep the memmory use of the different "match arms" balanced) - digest_len = DIGEST_LEN + wots_public_key = Array(V * DIGEST_LEN) for i in unroll(0, V): - match encoding[i]: - case 0: - var_1 = chain_tips + i * DIGEST_LEN - var_2 = public_key + i * DIGEST_LEN - var_3 = Array(digest_len) - var_4 = Array(digest_len) - poseidon16(var_1, ZERO_VEC_PTR, var_3) - poseidon16(var_3, ZERO_VEC_PTR, var_4) - poseidon16(var_4, ZERO_VEC_PTR, var_2) - case 1: - var_3 = Array(digest_len) - var_1 = chain_tips + i * DIGEST_LEN - var_2 = public_key + i * DIGEST_LEN - poseidon16(var_1, ZERO_VEC_PTR, var_3) - poseidon16(var_3, ZERO_VEC_PTR, var_2) - case 2: - var_1 = chain_tips + i * DIGEST_LEN - var_2 = public_key + i * DIGEST_LEN - poseidon16(var_1, ZERO_VEC_PTR, var_2) - case 3: - var_1 = chain_tips + (i * digest_len) - var_2 = public_key + (i * digest_len) - var_3 = var_1 + 3 - var_4 = var_2 + 3 - dot_product(var_1, ONE_VEC_PTR, var_2, 1, EE) - dot_product(var_3, ONE_VEC_PTR, var_4, 1, EE) - - wots_pubkey_hashed = slice_hash(public_key, V) - - debug_assert(log_lifetime < MAX_LOG_LIFETIME + 1) - - merkle_root: Imu - match log_lifetime: - case 0: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 0) - case 1: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 1) - case 2: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 2) - case 3: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 3) - case 4: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 4) - case 5: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 5) - case 6: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 6) - case 7: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 7) - case 8: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 8) - case 9: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 9) - case 10: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 10) - case 11: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 11) - case 12: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 12) - case 13: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 13) - case 14: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 14) - case 15: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 15) - case 16: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 16) - case 17: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 17) - case 18: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 18) - case 19: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 19) - case 20: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 20) - case 21: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 21) - case 22: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 22) - case 23: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 23) - case 24: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 24) - case 25: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 25) - case 26: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 26) - case 27: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 27) - case 28: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 28) - case 29: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 29) - case 30: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 30) - case 31: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 31) - case 32: - merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 32) - - return merkle_root - - -def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): + num_hashes = (CHAIN_LENGTH - 1) - encoding[i] + match_range(num_hashes, + range(0, 1), lambda _: copy_8(chain_tips + i * DIGEST_LEN, wots_public_key + i * DIGEST_LEN), + range(1, 2), lambda _: poseidon16(chain_tips + i * DIGEST_LEN, ZERO_VEC_PTR, wots_public_key + i * DIGEST_LEN), + range(0, CHAIN_LENGTH), lambda i: chain_hash(chain_tips + i * DIGEST_LEN, i, wots_public_key + i * DIGEST_LEN)) + + wots_pubkey_hashed = slice_hash(wots_public_key, V) + merkle_root_recovered = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_indexes, LOG_LIFETIME) + copy_8(merkle_root, merkle_root_recovered) + + +def chain_hash(input, n: Const, output): + debug_assert(2 < n) + states = Array((n-1) * DIGEST_LEN) + poseidon16(input, ZERO_VEC_PTR, states) + for i in unroll(1, n-1): + poseidon16(states + (i - 1) * DIGEST_LEN, ZERO_VEC_PTR, states + i * DIGEST_LEN) + poseidon16(states + (n - 2) * DIGEST_LEN, ZERO_VEC_PTR, output) + return + + +@inline +def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): states = Array(height * DIGEST_LEN) # First merkle round @@ -223,18 +147,30 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): def slice_hash(data, len: Const): - states = Array(len * DIGEST_LEN) - poseidon16(ZERO_VEC_PTR, data, states) - state_indexes = Array(len) - state_indexes[0] = states - for j in unroll(1, (len)): - state_indexes[j] = state_indexes[j - 1] + DIGEST_LEN - poseidon16(state_indexes[j - 1], data + j * DIGEST_LEN, state_indexes[j]) - return state_indexes[len - 1] + states = Array((len-1) * DIGEST_LEN) + poseidon16(data, data + DIGEST_LEN, states) + for i in unroll(1, len-1): + poseidon16(states + (i - 1) * DIGEST_LEN, data + DIGEST_LEN * (i + 1), states + i * DIGEST_LEN) + return data + + +@inline +def copy_8(x, y): + dot_product(x, ONE_VEC_PTR, y, 1, EE) + dot_product(x + (8 - DIM), ONE_VEC_PTR, y + (8 - DIM), 1, EE) + return + + +@inline +def copy_7(x, y): + dot_product(x, ONE_VEC_PTR, y, 1, EE) + x[5] = y[5] + y[6] = x[6] + return @inline -def assert_eq_vec(x, y): +def copy_6(x, y): dot_product(x, ONE_VEC_PTR, y, 1, EE) - dot_product(x + 3, ONE_VEC_PTR, y + 3, 1, EE) + x[5] = y[5] return diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index f0706d89..3cedc895 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -24,8 +24,9 @@ pub const NUM_CHAIN_HASHES: usize = 80; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; pub const V_GRINDING: usize = 1; pub const LOG_LIFETIME: usize = 32; -pub const RANDOMNESS_LEN_FE: usize = 8; +pub const RANDOMNESS_LEN_FE: usize = 7; pub const MESSAGE_LEN_FE: usize = 9; +pub const TRUNCATED_MERKLE_ROOT_LEN_FE: usize = 6; pub const SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + (V + LOG_LIFETIME) * DIGEST_SIZE; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 0b5314f4..b2cb667f 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -39,7 +39,7 @@ impl WotsSecretKey { &self, message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; 5], + truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], rng: &mut impl Rng, ) -> WotsSignature { let (randomness, encoding, _) = find_randomness_for_wots_encoding(message, slot, truncated_merkle_root, rng); @@ -55,7 +55,7 @@ impl WotsSignature { &self, message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; 5], + truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], signature: &Self, ) -> Option { self.recover_public_key_with_poseidon_trace(message, slot, truncated_merkle_root, signature, &mut Vec::new()) @@ -65,7 +65,7 @@ impl WotsSignature { &self, message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; 5], + truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], signature: &Self, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option { @@ -117,7 +117,7 @@ pub fn iterate_hash_with_poseidon_trace( pub fn find_randomness_for_wots_encoding( message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; 5], + truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], rng: &mut impl Rng, ) -> ([F; RANDOMNESS_LEN_FE], [u8; V], usize) { let mut num_iters = 0; @@ -133,7 +133,7 @@ pub fn find_randomness_for_wots_encoding( pub fn wots_encode( message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; 5], + truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], randomness: &[F; RANDOMNESS_LEN_FE], ) -> Option<[u8; V]> { wots_encode_with_poseidon_trace(message, slot, truncated_merkle_root, randomness, &mut Vec::new()) @@ -142,25 +142,28 @@ pub fn wots_encode( pub fn wots_encode_with_poseidon_trace( message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; 5], + truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], randomness: &[F; RANDOMNESS_LEN_FE], poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>, ) -> Option<[u8; V]> { // Encode slot as 2 field elements (16 bits each) let [slot_lo, slot_hi] = slot_to_field_elements(slot); - // A = poseidon(message (9 fe), slot (2 fe), truncated_merkle_root (5 fe)) + // A = poseidon(message (9 fe), randomness (7 fe)) let mut a_input_right = [F::default(); 8]; a_input_right[0] = message[8]; - a_input_right[1] = slot_lo; - a_input_right[2] = slot_hi; - a_input_right[3..8].copy_from_slice(truncated_merkle_root); + a_input_right[1..1 + RANDOMNESS_LEN_FE].copy_from_slice(randomness); let a = poseidon16_compress_with_trace(message[..8].try_into().unwrap(), &a_input_right, poseidon_16_trace); - // B = poseidon(A (8 fe), randomness (8 fe)) - let compressed = poseidon16_compress_with_trace(&a, randomness, poseidon_16_trace); + // B = poseidon(A (8 fe), slot (2 fe), truncated_merkle_root (6 fe)) + let mut b_input_right = [F::default(); 8]; + b_input_right[0] = slot_lo; + b_input_right[1] = slot_hi; + b_input_right[2..8].copy_from_slice(truncated_merkle_root); + let compressed = poseidon16_compress_with_trace(&a, &b_input_right, poseidon_16_trace); if compressed.iter().any(|&kb| kb == -F::ONE) { + // ensures uniformity of encoding return None; } let all_indices: Vec<_> = compressed diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 4087e3de..07cb86a2 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -123,7 +123,7 @@ pub fn xmss_sign( } let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot as u64); let merkle_root = secret_key.public_key().merkle_root; - let truncated_merkle_root: [F; 5] = merkle_root[0..5].try_into().unwrap(); + let truncated_merkle_root: [F; TRUNCATED_MERKLE_ROOT_LEN_FE] = merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); let wots_signature = wots_secret_key.sign(message_hash, slot, &truncated_merkle_root, rng); let merkle_proof = (0..LOG_LIFETIME) .map(|level| { @@ -172,7 +172,7 @@ pub fn xmss_verify_with_poseidon_trace( signature: &XmssSignature, ) -> Result { let mut poseidon_16_trace = Vec::new(); - let truncated_merkle_root: [F; 5] = pub_key.merkle_root[0..5].try_into().unwrap(); + let truncated_merkle_root: [F; TRUNCATED_MERKLE_ROOT_LEN_FE] = pub_key.merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); let wots_public_key = signature .wots_signature .recover_public_key_with_poseidon_trace( diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index d5cb7ee7..3e41661b 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -26,7 +26,7 @@ fn encoding_grinding_bits() { .map(|i| { let message: [F; MESSAGE_LEN_FE] = Default::default(); let slot = i as u32; - let truncated_merkle_root: [F; 5] = Default::default(); + let truncated_merkle_root: [F; TRUNCATED_MERKLE_ROOT_LEN_FE] = Default::default(); let mut rng = StdRng::seed_from_u64(i as u64); let (_randomness, _encoding, num_iters) = find_randomness_for_wots_encoding(&message, slot, &truncated_merkle_root, &mut rng); From 4eb069c3b55074bb09bed10307886e6f2b173816 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 14:27:44 +0400 Subject: [PATCH 08/29] wip --- crates/rec_aggregation/src/xmss_aggregate.rs | 2 +- crates/rec_aggregation/xmss_aggregate.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 1111ba3a..3b180c92 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -202,5 +202,5 @@ fn precompute_poseidons( #[test] fn test_xmss_aggregate() { - run_xmss_benchmark(10, false); + run_xmss_benchmark(1, false); } diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 344e05a1..9c894354 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -96,11 +96,12 @@ def xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_indexe match_range(num_hashes, range(0, 1), lambda _: copy_8(chain_tips + i * DIGEST_LEN, wots_public_key + i * DIGEST_LEN), range(1, 2), lambda _: poseidon16(chain_tips + i * DIGEST_LEN, ZERO_VEC_PTR, wots_public_key + i * DIGEST_LEN), - range(0, CHAIN_LENGTH), lambda i: chain_hash(chain_tips + i * DIGEST_LEN, i, wots_public_key + i * DIGEST_LEN)) + range(2, CHAIN_LENGTH), lambda num_hashes_const: chain_hash(chain_tips + i * DIGEST_LEN, num_hashes_const, wots_public_key + i * DIGEST_LEN)) wots_pubkey_hashed = slice_hash(wots_public_key, V) merkle_root_recovered = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_indexes, LOG_LIFETIME) copy_8(merkle_root, merkle_root_recovered) + return def chain_hash(input, n: Const, output): From 6000be6e59cf0517763c8331c96a0dee3bcf5113 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 14:32:16 +0400 Subject: [PATCH 09/29] wip --- crates/rec_aggregation/xmss_aggregate.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 9c894354..467ae917 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -105,7 +105,7 @@ def xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_indexe def chain_hash(input, n: Const, output): - debug_assert(2 < n) + debug_assert(2 <= n) states = Array((n-1) * DIGEST_LEN) poseidon16(input, ZERO_VEC_PTR, states) for i in unroll(1, n-1): @@ -165,7 +165,7 @@ def copy_8(x, y): @inline def copy_7(x, y): dot_product(x, ONE_VEC_PTR, y, 1, EE) - x[5] = y[5] + y[5] = x[5] y[6] = x[6] return @@ -173,5 +173,11 @@ def copy_7(x, y): @inline def copy_6(x, y): dot_product(x, ONE_VEC_PTR, y, 1, EE) - x[5] = y[5] + y[5] = x[5] return + + +def print_digest(digest): + for i in unroll(0, DIGEST_LEN): + print(digest[i]) + return \ No newline at end of file From bf5a738e76277550ba3910708569bd3e916885d3 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 14:40:14 +0400 Subject: [PATCH 10/29] wip --- crates/rec_aggregation/src/xmss_aggregate.rs | 3 ++- crates/rec_aggregation/xmss_aggregate.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 3b180c92..30ed8905 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -8,7 +8,7 @@ use std::time::Instant; use std::{collections::BTreeMap, path::Path}; use tracing::{info_span, instrument}; use xmss::{ - LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, SIG_SIZE_FE, TARGET_SUM, V, V_GRINDING, W, XmssPublicKey, XmssSignature, slot_to_field_elements, xmss_verify_with_poseidon_trace + LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, RANDOMNESS_LEN_FE, SIG_SIZE_FE, TARGET_SUM, V, V_GRINDING, W, XmssPublicKey, XmssSignature, slot_to_field_elements, xmss_verify_with_poseidon_trace }; static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); @@ -76,6 +76,7 @@ fn compile_xmss_aggregation_program() -> Bytecode { replacements.insert("TARGET_SUM_PLACEHOLDER".to_string(), TARGET_SUM.to_string()); replacements.insert("LOG_LIFETIME_PLACEHOLDER".to_string(), LOG_LIFETIME.to_string()); replacements.insert("MESSAGE_LEN_PLACEHOLDER".to_string(), MESSAGE_LEN_FE.to_string()); + replacements.insert("RANDOMNESS_LEN_PLACEHOLDER".to_string(), RANDOMNESS_LEN_FE.to_string()); compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements }) } diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 467ae917..e81ef532 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -8,7 +8,7 @@ LOG_LIFETIME = LOG_LIFETIME_PLACEHOLDER MESSAGE_LEN = MESSAGE_LEN_PLACEHOLDER DIGEST_LEN = 8 -RANDOMNESS_LEN = 8 +RANDOMNESS_LEN = RANDOMNESS_LEN_PLACEHOLDER SIG_SIZE = RANDOMNESS_LEN + (V + LOG_LIFETIME) * DIGEST_LEN NUM_ENCODING_FE = div_ceil((V + V_GRINDING), (24 / W)) # 24 should be divisible by W (works for W=2,3,4) From 2d9eea6a121f83b5c372002e5977002dcb503054 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 14:48:28 +0400 Subject: [PATCH 11/29] w --- crates/xmss/src/wots.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index b2cb667f..99bac71c 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -92,7 +92,8 @@ impl WotsPublicKey { } pub fn hash_with_poseidon_trace(&self, poseidon_16_trace: &mut Poseidon16History) -> Digest { - self.0.iter().fold(Digest::default(), |digest, chunk| { + let init = poseidon16_compress_with_trace(&self.0[0], &self.0[1], poseidon_16_trace); + self.0[2..].iter().fold(init, |digest, chunk| { poseidon16_compress_with_trace(&digest, chunk, poseidon_16_trace) }) } From 27a79335dbf2d443a0e27f1493e54d32e42a493a Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 14:52:52 +0400 Subject: [PATCH 12/29] w --- crates/rec_aggregation/xmss_aggregate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index e81ef532..9e4a168c 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -152,7 +152,7 @@ def slice_hash(data, len: Const): poseidon16(data, data + DIGEST_LEN, states) for i in unroll(1, len-1): poseidon16(states + (i - 1) * DIGEST_LEN, data + DIGEST_LEN * (i + 1), states + i * DIGEST_LEN) - return data + return states + (len - 2) * DIGEST_LEN @inline From 9bc163bb7fd077b7fc52b05dd89300e4116d0d49 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 15:26:26 +0400 Subject: [PATCH 13/29] fix --- crates/rec_aggregation/src/xmss_aggregate.rs | 6 ++++-- crates/rec_aggregation/xmss_aggregate.py | 9 +++++---- crates/xmss/src/lib.rs | 2 +- crates/xmss/src/xmss.rs | 4 ++-- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 30ed8905..a7d16ada 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -8,7 +8,8 @@ use std::time::Instant; use std::{collections::BTreeMap, path::Path}; use tracing::{info_span, instrument}; use xmss::{ - LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, RANDOMNESS_LEN_FE, SIG_SIZE_FE, TARGET_SUM, V, V_GRINDING, W, XmssPublicKey, XmssSignature, slot_to_field_elements, xmss_verify_with_poseidon_trace + LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, RANDOMNESS_LEN_FE, SIG_SIZE_FE, TARGET_SUM, V, V_GRINDING, W, + XmssPublicKey, XmssSignature, slot_to_field_elements, xmss_verify_with_poseidon_trace, }; static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); @@ -102,6 +103,7 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { let end = slot + rng.random_range(1..5); let (sk, pk) = xmss::xmss_key_gen(rng.random(), start, end).unwrap(); let sig = xmss::xmss_sign(&mut rng, &sk, &message, slot).unwrap(); + xmss::xmss_verify(&pk, &message, &sig).unwrap(); // sanity check (pk, sig) }) .collect::>(); @@ -203,5 +205,5 @@ fn precompute_poseidons( #[test] fn test_xmss_aggregate() { - run_xmss_benchmark(1, false); + run_xmss_benchmark(5, false); } diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 9e4a168c..a4aab368 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -121,9 +121,10 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): # First merkle round match leaf_position_bits[0]: case 0: - poseidon16(leaf_digest, merkle_path, states) - case 1: poseidon16(merkle_path, leaf_digest, states) + case 1: + poseidon16(leaf_digest, merkle_path, states) + # Remaining merkle rounds state_indexes = Array(height) @@ -134,14 +135,14 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): match leaf_position_bits[j]: case 0: poseidon16( - state_indexes[j - 1], merkle_path + j * DIGEST_LEN, + state_indexes[j - 1], state_indexes[j], ) case 1: poseidon16( - merkle_path + j * DIGEST_LEN, state_indexes[j - 1], + merkle_path + j * DIGEST_LEN, state_indexes[j], ) return state_indexes[height - 1] diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 3cedc895..c53f788a 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -36,5 +36,5 @@ fn poseidon16_compress_with_trace(a: &Digest, b: &Digest, poseidon_16_trace: &mu let input: [F; 16] = [*a, *b].concat().try_into().unwrap(); let output = poseidon16_compress(input); poseidon_16_trace.push((input, output)); - output[0..DIGEST_SIZE].try_into().unwrap() + output } diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 07cb86a2..3008f38c 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -123,7 +123,7 @@ pub fn xmss_sign( } let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot as u64); let merkle_root = secret_key.public_key().merkle_root; - let truncated_merkle_root: [F; TRUNCATED_MERKLE_ROOT_LEN_FE] = merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); + let truncated_merkle_root = merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); let wots_signature = wots_secret_key.sign(message_hash, slot, &truncated_merkle_root, rng); let merkle_proof = (0..LOG_LIFETIME) .map(|level| { @@ -172,7 +172,7 @@ pub fn xmss_verify_with_poseidon_trace( signature: &XmssSignature, ) -> Result { let mut poseidon_16_trace = Vec::new(); - let truncated_merkle_root: [F; TRUNCATED_MERKLE_ROOT_LEN_FE] = pub_key.merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); + let truncated_merkle_root = pub_key.merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); let wots_public_key = signature .wots_signature .recover_public_key_with_poseidon_trace( From ce4decd27b688e4220ecfaa36eb7d874bb31eeee Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 15:33:29 +0400 Subject: [PATCH 14/29] compiles --- src/lib.rs | 48 ++++++++++++++++++++++++++---------------------- src/main.rs | 4 +--- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 78ee4d3c..91de27c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,17 +4,7 @@ pub use multilinear_toolkit::prelude::{ ProofError, }; pub use rec_aggregation::xmss_aggregate::{xmss_aggregate_signatures, xmss_verify_aggregated_signatures}; -pub use xmss::{ - XMSS_MAX_LOG_LIFETIME, - XMSS_MIN_LOG_LIFETIME, - XmssPublicKey, - XmssSecretKey, - XmssSignature, - xmss_generate_phony_signatures, // useful for tests - xmss_key_gen, - xmss_sign, - xmss_verify, -}; +pub use xmss::{XmssPublicKey, XmssSecretKey, XmssSignature, xmss_key_gen, xmss_sign, xmss_verify}; pub fn xmss_aggregation_setup_prover() { rec_aggregation::xmss_aggregate::xmss_setup_aggregation_program(); @@ -34,19 +24,21 @@ WARNING: Toy XMSS, do not consider this secure (for now)! #[cfg(test)] mod tests { use super::*; + use rand::{Rng, SeedableRng, rngs::StdRng}; + use xmss::MESSAGE_LEN_FE; #[test] fn test_xmss_signature() { - let first_slot = 555; - let log_lifetime = 10; - let slot = first_slot + 10; + let start = 555; + let end = 565; + let slot = 560; let key_gen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); - let randomness_seed: [u8; 32] = std::array::from_fn(|i| (i * 2) as u8); - let message_hash: [F; 8] = std::array::from_fn(|i| F::from_usize(i * 3)); + let message_hash: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 3)); - let (secret_key, pub_key) = xmss_key_gen(key_gen_seed, first_slot, log_lifetime).unwrap(); - let signature = xmss_sign(randomness_seed, &secret_key, &message_hash, slot).unwrap(); - xmss_verify(&pub_key, &message_hash, &signature, slot).unwrap(); + let (secret_key, pub_key) = xmss_key_gen(key_gen_seed, start, end).unwrap(); + let mut rng = StdRng::from_seed(std::array::from_fn(|i| (i * 2) as u8)); + let signature = xmss_sign(&mut rng, &secret_key, &message_hash, slot).unwrap(); + xmss_verify(&pub_key, &message_hash, &signature).unwrap(); } #[test] @@ -58,11 +50,23 @@ mod tests { // (Actually, no need to call it if `xmss_aggregation_setup_prover` was already called) xmss_aggregation_setup_verifier(); - let log_lifetimes = (XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME).collect::>(); - let message_hash: [F; 8] = std::array::from_fn(|i| F::from_usize(i * 7)); + let n_signatures = 3; + let message_hash: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 7)); let slot = 77777; - let (xmss_pub_keys, all_signatures) = xmss_generate_phony_signatures(&log_lifetimes, message_hash, slot); + let pub_keys_and_sigs: Vec<_> = (0..n_signatures) + .into_par_iter() + .map(|i| { + let mut rng = StdRng::seed_from_u64(i as u64); + let start = slot - 5; + let end = slot + 5; + let (sk, pk) = xmss_key_gen(rng.random(), start, end).unwrap(); + let sig = xmss_sign(&mut rng, &sk, &message_hash, slot).unwrap(); + xmss_verify(&pk, &message_hash, &sig).unwrap(); + (pk, sig) + }) + .collect(); + let (xmss_pub_keys, all_signatures): (Vec<_>, Vec<_>) = pub_keys_and_sigs.into_iter().unzip(); let proof = xmss_aggregate_signatures(&xmss_pub_keys, &all_signatures, message_hash, slot).unwrap(); xmss_verify_aggregated_signatures(&xmss_pub_keys, message_hash, &proof, slot).unwrap(); diff --git a/src/main.rs b/src/main.rs index 2781712d..0dc01694 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ use clap::Parser; mod prove_poseidons; use rec_aggregation::{recursion::run_recursion_benchmark, xmss_aggregate::run_xmss_benchmark}; -use xmss::XMSS_MAX_LOG_LIFETIME; use crate::prove_poseidons::benchmark_prove_poseidon_16; @@ -35,8 +34,7 @@ fn main() { match cli { Cli::Xmss { n_signatures, tracing } => { - let log_lifetimes = (0..n_signatures).map(|_| XMSS_MAX_LOG_LIFETIME).collect::>(); - run_xmss_benchmark(&log_lifetimes, tracing); + run_xmss_benchmark(n_signatures, tracing); } Cli::Recursion { n, tracing } => { run_recursion_benchmark(n, tracing); From 0bac9730755fc26792cea132843aded331a1965a Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 15:39:28 +0400 Subject: [PATCH 15/29] wip --- crates/rec_aggregation/src/xmss_aggregate.rs | 4 ++-- crates/xmss/src/lib.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index a7d16ada..12fadd7d 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -99,8 +99,8 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { .into_par_iter() .map(|i| { let mut rng = StdRng::seed_from_u64(i as u64); - let start = slot - rng.random_range(0..5); - let end = slot + rng.random_range(1..5); + let start = slot - rng.random_range(0..3); + let end = slot + rng.random_range(1..3); let (sk, pk) = xmss::xmss_key_gen(rng.random(), start, end).unwrap(); let sig = xmss::xmss_sign(&mut rng, &sk, &message, slot).unwrap(); xmss::xmss_verify(&pk, &message, &sig).unwrap(); // sanity check diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index c53f788a..347c47aa 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -17,12 +17,12 @@ type F = KoalaBear; type Digest = [F; DIGEST_SIZE]; // WOTS -pub const V: usize = 66; +pub const V: usize = 65; pub const W: usize = 2; pub const CHAIN_LENGTH: usize = 1 << W; pub const NUM_CHAIN_HASHES: usize = 80; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; -pub const V_GRINDING: usize = 1; +pub const V_GRINDING: usize = 4; pub const LOG_LIFETIME: usize = 32; pub const RANDOMNESS_LEN_FE: usize = 7; pub const MESSAGE_LEN_FE: usize = 9; From 11d1a76ed693c1392c231b8dd8daa4aa2e7a67e9 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 15:48:41 +0400 Subject: [PATCH 16/29] fix --- crates/lean_vm/src/isa/hint.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index ea473cbb..29512ed8 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -110,7 +110,7 @@ impl CustomHint { for i in 0..num_to_decompose { let value = ctx.memory.get(to_decompose_ptr + i)?.to_usize(); for i in 0..24 / w { - let value = F::from_usize((value >> (w * i)) & 0b11); + let value = F::from_usize((value >> (w * i)) & ((1 << w) - 1)); ctx.memory.set(memory_index_decomposed, value)?; memory_index_decomposed += 1; } From 3c0caed740986dccfc0beabe182420c0eceb088c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 15:53:30 +0400 Subject: [PATCH 17/29] works --- crates/rec_aggregation/src/xmss_aggregate.rs | 11 +++++++++-- crates/xmss/src/lib.rs | 8 ++++---- crates/xmss/src/wots.rs | 2 +- crates/xmss/src/xmss.rs | 4 ++-- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 12fadd7d..cf2476d5 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -95,7 +95,8 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { .unwrap(); let slot = 1111; - let pub_keys_and_sigs = (0..n_signatures) + let n_reps = 10; + let pub_keys_and_sigs = (0..n_signatures / n_reps) .into_par_iter() .map(|i| { let mut rng = StdRng::seed_from_u64(i as u64); @@ -107,7 +108,13 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { (pk, sig) }) .collect::>(); - let (xmss_pub_keys, all_signatures): (Vec<_>, Vec<_>) = pub_keys_and_sigs.into_iter().unzip(); + + let (mut xmss_pub_keys, mut all_signatures) = (vec![], vec![]); + for i in 0..n_signatures { + let (pk, sig) = &pub_keys_and_sigs[i % (n_signatures / n_reps)]; + xmss_pub_keys.push(pk.clone()); + all_signatures.push(sig.clone()); + } let time = Instant::now(); let (proof_data, n_field_elements_in_proof, summary) = xmss_aggregate_signatures_helper(&xmss_pub_keys, &all_signatures, message, slot).unwrap(); diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 347c47aa..964eae6d 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -17,12 +17,12 @@ type F = KoalaBear; type Digest = [F; DIGEST_SIZE]; // WOTS -pub const V: usize = 65; -pub const W: usize = 2; +pub const V: usize = 40; +pub const W: usize = 3; pub const CHAIN_LENGTH: usize = 1 << W; -pub const NUM_CHAIN_HASHES: usize = 80; +pub const NUM_CHAIN_HASHES: usize = 120; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; -pub const V_GRINDING: usize = 4; +pub const V_GRINDING: usize = 3; pub const LOG_LIFETIME: usize = 32; pub const RANDOMNESS_LEN_FE: usize = 7; pub const MESSAGE_LEN_FE: usize = 9; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 99bac71c..003a1d6e 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -13,7 +13,7 @@ pub struct WotsSecretKey { #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct WotsPublicKey(pub [Digest; V]); -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct WotsSignature { pub chain_tips: [Digest; V], pub randomness: [F; RANDOMNESS_LEN_FE], diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 3008f38c..a42c6c1b 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -14,14 +14,14 @@ pub struct XmssSecretKey { pub(crate) merkle_tree: Vec>, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct XmssSignature { pub wots_signature: WotsSignature, pub slot: u32, pub merkle_proof: Vec, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct XmssPublicKey { pub merkle_root: Digest, } From 4692f1ee75544463595282bfc00a66df857644f7 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 21:20:00 +0400 Subject: [PATCH 18/29] wip --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6271a1df..f9323edc 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Machine: M4 Max 48GB (CPU only) To reproduce: - `cargo run --release -- poseidon --log-n-perms 20` - `cargo run --release -- recursion --n 2` -- `cargo run --release -- xmss --n-signatures 1350` +- `cargo run --release -- xmss --n-signatures 1300` ## Proof size From 0eb56705e6d99e5a65cdcc8a9456d48a21103051 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 21:28:21 +0400 Subject: [PATCH 19/29] wip --- crates/rec_aggregation/xmss_aggregate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index a4aab368..0291fc86 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -148,7 +148,8 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): return state_indexes[height - 1] -def slice_hash(data, len: Const): +@inline +def slice_hash(data, len): states = Array((len-1) * DIGEST_LEN) poseidon16(data, data + DIGEST_LEN, states) for i in unroll(1, len-1): @@ -166,8 +167,7 @@ def copy_8(x, y): @inline def copy_7(x, y): dot_product(x, ONE_VEC_PTR, y, 1, EE) - y[5] = x[5] - y[6] = x[6] + dot_product(x + (7-DIM), ONE_VEC_PTR, y + (7-DIM), 1, EE) return From 70466cda7cda245a17c325cad6ef11d146828a5d Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 21:33:32 +0400 Subject: [PATCH 20/29] w --- crates/rec_aggregation/src/xmss_aggregate.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index cf2476d5..71076b3f 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -96,7 +96,7 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { let slot = 1111; let n_reps = 10; - let pub_keys_and_sigs = (0..n_signatures / n_reps) + let pub_keys_and_sigs = (0..n_signatures.div_ceil(n_reps)) .into_par_iter() .map(|i| { let mut rng = StdRng::seed_from_u64(i as u64); @@ -108,10 +108,10 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { (pk, sig) }) .collect::>(); - + let (mut xmss_pub_keys, mut all_signatures) = (vec![], vec![]); for i in 0..n_signatures { - let (pk, sig) = &pub_keys_and_sigs[i % (n_signatures / n_reps)]; + let (pk, sig) = &pub_keys_and_sigs[i % (n_signatures.div_ceil(n_reps))]; xmss_pub_keys.push(pk.clone()); all_signatures.push(sig.clone()); } From 8f4b63f9a56ed8229e7fc77d1f52992d57e56df2 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 21:41:03 +0400 Subject: [PATCH 21/29] opti --- README.md | 2 +- crates/rec_aggregation/xmss_aggregate.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f9323edc..6271a1df 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Machine: M4 Max 48GB (CPU only) To reproduce: - `cargo run --release -- poseidon --log-n-perms 20` - `cargo run --release -- recursion --n 2` -- `cargo run --release -- xmss --n-signatures 1300` +- `cargo run --release -- xmss --n-signatures 1350` ## Proof size diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 0291fc86..691ee253 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -108,9 +108,12 @@ def chain_hash(input, n: Const, output): debug_assert(2 <= n) states = Array((n-1) * DIGEST_LEN) poseidon16(input, ZERO_VEC_PTR, states) + state_indexes = Array(n - 1) + state_indexes[0] = states for i in unroll(1, n-1): - poseidon16(states + (i - 1) * DIGEST_LEN, ZERO_VEC_PTR, states + i * DIGEST_LEN) - poseidon16(states + (n - 2) * DIGEST_LEN, ZERO_VEC_PTR, output) + state_indexes[i] = state_indexes[i - 1] + DIGEST_LEN + poseidon16(state_indexes[i - 1], ZERO_VEC_PTR, state_indexes[i]) + poseidon16(state_indexes[n - 2], ZERO_VEC_PTR, output) return @@ -131,7 +134,7 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): state_indexes[0] = states for j in unroll(1, height): state_indexes[j] = state_indexes[j - 1] + DIGEST_LEN - # Warning: this works only if leaf_position_bits[i] is known to be boolean: + # Warning: this works only if leaf_position_bits[j] is known to be boolean: match leaf_position_bits[j]: case 0: poseidon16( @@ -152,9 +155,12 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): def slice_hash(data, len): states = Array((len-1) * DIGEST_LEN) poseidon16(data, data + DIGEST_LEN, states) + state_indexes = Array(len - 1) + state_indexes[0] = states for i in unroll(1, len-1): - poseidon16(states + (i - 1) * DIGEST_LEN, data + DIGEST_LEN * (i + 1), states + i * DIGEST_LEN) - return states + (len - 2) * DIGEST_LEN + state_indexes[i] = state_indexes[i - 1] + DIGEST_LEN + poseidon16(state_indexes[i - 1], data + DIGEST_LEN * (i + 1), state_indexes[i]) + return state_indexes[len - 2] @inline From 92e2d1cdedc0b12080f6ac9cebdde11aa995acfe Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 21:49:55 +0400 Subject: [PATCH 22/29] opti --- crates/rec_aggregation/xmss_aggregate.py | 31 +++++++++++++++++------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 691ee253..2c69f474 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -99,8 +99,7 @@ def xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_indexe range(2, CHAIN_LENGTH), lambda num_hashes_const: chain_hash(chain_tips + i * DIGEST_LEN, num_hashes_const, wots_public_key + i * DIGEST_LEN)) wots_pubkey_hashed = slice_hash(wots_public_key, V) - merkle_root_recovered = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_indexes, LOG_LIFETIME) - copy_8(merkle_root, merkle_root_recovered) + merkle_verify(wots_pubkey_hashed, merkle_path, merkle_indexes, LOG_LIFETIME, merkle_root) return @@ -118,8 +117,8 @@ def chain_hash(input, n: Const, output): @inline -def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): - states = Array(height * DIGEST_LEN) +def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height, expected_root): + states = Array((height - 1) * DIGEST_LEN) # First merkle round match leaf_position_bits[0]: @@ -128,11 +127,10 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): case 1: poseidon16(leaf_digest, merkle_path, states) - - # Remaining merkle rounds - state_indexes = Array(height) + # Middle merkle rounds + state_indexes = Array(height - 1) state_indexes[0] = states - for j in unroll(1, height): + for j in unroll(1, height - 1): state_indexes[j] = state_indexes[j - 1] + DIGEST_LEN # Warning: this works only if leaf_position_bits[j] is known to be boolean: match leaf_position_bits[j]: @@ -148,7 +146,22 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height): merkle_path + j * DIGEST_LEN, state_indexes[j], ) - return state_indexes[height - 1] + + # Final merkle round + match leaf_position_bits[height - 1]: + case 0: + poseidon16( + merkle_path + (height - 1) * DIGEST_LEN, + state_indexes[height - 2], + expected_root, + ) + case 1: + poseidon16( + state_indexes[height - 2], + merkle_path + (height - 1) * DIGEST_LEN, + expected_root, + ) + return @inline From db3bd66af78c68f5446427cd6ae4ece619ae2bfe Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 21:50:06 +0400 Subject: [PATCH 23/29] naming --- crates/rec_aggregation/xmss_aggregate.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 691ee253..39d8d6bb 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -40,8 +40,8 @@ def xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_indexe # signature: randomness | chain_tips | merkle_path # return the hashed xmss public key randomness = signature - chain_tips = signature + RANDOMNESS_LEN - merkle_path = chain_tips + V * DIGEST_LEN + chain_starts = signature + RANDOMNESS_LEN + merkle_path = chain_starts + V * DIGEST_LEN # 1) We encode message_hash + randomness into the layer of the hypercube with target sum = TARGET_SUM @@ -93,10 +93,12 @@ def xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_indexe for i in unroll(0, V): num_hashes = (CHAIN_LENGTH - 1) - encoding[i] + chain_start = chain_starts + i * DIGEST_LEN + chain_end = wots_public_key + i * DIGEST_LEN match_range(num_hashes, - range(0, 1), lambda _: copy_8(chain_tips + i * DIGEST_LEN, wots_public_key + i * DIGEST_LEN), - range(1, 2), lambda _: poseidon16(chain_tips + i * DIGEST_LEN, ZERO_VEC_PTR, wots_public_key + i * DIGEST_LEN), - range(2, CHAIN_LENGTH), lambda num_hashes_const: chain_hash(chain_tips + i * DIGEST_LEN, num_hashes_const, wots_public_key + i * DIGEST_LEN)) + range(0, 1), lambda _: copy_8(chain_start, chain_end), + range(1, 2), lambda _: poseidon16(chain_start, ZERO_VEC_PTR, chain_end), + range(2, CHAIN_LENGTH), lambda num_hashes_const: chain_hash(chain_start, num_hashes_const, chain_end)) wots_pubkey_hashed = slice_hash(wots_public_key, V) merkle_root_recovered = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_indexes, LOG_LIFETIME) From dc2237f60dc40d7db2fea9e31ca11c4afc418e57 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 22:09:33 +0400 Subject: [PATCH 24/29] colors --- Cargo.lock | 8 ++++++++ Cargo.toml | 2 ++ crates/lean_prover/Cargo.toml | 1 + crates/lean_prover/src/prove_execution.rs | 10 ++++++---- crates/sub_protocols/Cargo.toml | 1 + crates/sub_protocols/src/generic_logup.rs | 12 +++++++++++- crates/sub_protocols/src/packed_pcs.rs | 12 +++++++++++- 7 files changed, 40 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d4025b25..ef4ac00e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -401,6 +401,7 @@ dependencies = [ "lean_compiler", "lean_vm", "multilinear-toolkit", + "owo-colors", "p3-koala-bear", "p3-poseidon2", "p3-symmetric", @@ -553,6 +554,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "owo-colors" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c6901729fa79e91a0913333229e9ca5dc725089d1c363b2f4b4760709dc4a52" + [[package]] name = "p3-challenger" version = "0.3.0" @@ -1038,6 +1045,7 @@ version = "0.1.0" dependencies = [ "lean_vm", "multilinear-toolkit", + "owo-colors", "p3-koala-bear", "p3-util", "rand", diff --git a/Cargo.toml b/Cargo.toml index 09382385..d96c6ec0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ strum = { version = "0.27.2", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] } bincode = "1.3.3" num_enum = "0.7.5" +owo-colors = "4.2.3" tracing-subscriber = { version = "0.3.19", features = ["std", "env-filter"] } tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] } p3-koala-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } @@ -88,5 +89,6 @@ p3-koala-bear.workspace = true lean_vm.workspace = true multilinear-toolkit.workspace = true + [profile.release] lto = "thin" diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index 48ec0a47..ef2ee111 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -17,6 +17,7 @@ p3-koala-bear.workspace = true p3-symmetric.workspace = true p3-util.workspace = true tracing.workspace = true +owo-colors.workspace = true air.workspace = true sub_protocols.workspace = true lean_vm.workspace = true diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index b4e0902f..4124afc1 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -4,6 +4,7 @@ use crate::*; use air::prove_air; use lean_vm::*; +use owo_colors::OwoColorize; use sub_protocols::*; use tracing::info_span; use utils::build_prover_state; @@ -62,13 +63,14 @@ pub fn prove_execution( let mut table_log = String::new(); for (table, trace) in &traces { table_log.push_str(&format!( - "{}: 2^{:.2} rows |", + "{}: 2^{} * (1 + {:.2}) rows | ", table.name(), - f64::log2(trace.non_padded_n_rows as f64) + trace.log_n_rows - 1, + (trace.non_padded_n_rows as f64) / (1 << (trace.log_n_rows - 1)) as f64 - 1.0 )); } - table_log.pop(); // remove last '|' - info_span!("Trace tables sizes: {}", table_log).in_scope(|| {}); + table_log = table_log.trim_end_matches(" | ").to_string(); + tracing::info!("Trace tables sizes: {}", table_log.magenta()); // TODO parrallelize let mut memory_acc = F::zero_vec(memory.len()); diff --git a/crates/sub_protocols/Cargo.toml b/crates/sub_protocols/Cargo.toml index bc353456..c0b42c50 100644 --- a/crates/sub_protocols/Cargo.toml +++ b/crates/sub_protocols/Cargo.toml @@ -10,6 +10,7 @@ workspace = true tracing.workspace = true utils.workspace = true p3-util.workspace = true +owo-colors.workspace = true lean_vm.workspace = true multilinear-toolkit.workspace = true diff --git a/crates/sub_protocols/src/generic_logup.rs b/crates/sub_protocols/src/generic_logup.rs index 84eb3ce0..3a3fa3b9 100644 --- a/crates/sub_protocols/src/generic_logup.rs +++ b/crates/sub_protocols/src/generic_logup.rs @@ -1,6 +1,7 @@ use crate::{prove_gkr_quotient, verify_gkr_quotient}; use lean_vm::*; use multilinear_toolkit::prelude::*; +use owo_colors::OwoColorize; use std::collections::BTreeMap; use tracing::instrument; use utils::*; @@ -201,7 +202,16 @@ pub fn prove_generic_logup( } assert_eq!(log2_ceil_usize(offset), total_gkr_n_vars); - tracing::info!("Logup data: {} = 2^{:.2}", offset, (offset as f64).log2()); + tracing::info!( + "{}", + format!( + "Logup data: {} = 2^{} * (1 + {:.2})", + offset, + total_gkr_n_vars - 1, + (offset as f64) / (1 << (total_gkr_n_vars - 1)) as f64 - 1.0 + ) + .blue() + ); denominators[offset..].par_iter_mut().for_each(|d| *d = EF::ONE); // padding diff --git a/crates/sub_protocols/src/packed_pcs.rs b/crates/sub_protocols/src/packed_pcs.rs index 80e6af87..4f57f808 100644 --- a/crates/sub_protocols/src/packed_pcs.rs +++ b/crates/sub_protocols/src/packed_pcs.rs @@ -1,6 +1,7 @@ use lean_vm::{COL_PC, CommittedStatements, ENDING_PC, STARTING_PC, sort_tables_by_height}; use lean_vm::{EF, F, Table, TableT, TableTrace}; use multilinear_toolkit::prelude::*; +use owo_colors::OwoColorize; use p3_util::log2_ceil_usize; use std::collections::BTreeMap; use tracing::instrument; @@ -106,7 +107,16 @@ pub fn packed_pcs_commit( } } assert_eq!(log2_ceil_usize(offset), packed_n_vars); - tracing::info!("packed PCS data: {} = 2^{:.2}", offset, (offset as f64).log2()); + tracing::info!( + "{}", + format!( + "packed PCS data: {} = 2^{} * (1 + {:.2})", + offset, + packed_n_vars - 1, + (offset as f64) / (1 << (packed_n_vars - 1)) as f64 - 1.0 + ) + .green() + ); let packed_polynomial = MleOwned::Base(packed_polynomial); From 20880293085d0ffbb5b9031faac8c82f3a5944bd Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 23:27:38 +0400 Subject: [PATCH 25/29] readme --- README.md | 40 +++++++++------ crates/lean_prover/src/lib.rs | 11 +++-- crates/lean_prover/src/test_zkvm.rs | 6 +-- crates/rec_aggregation/src/recursion.rs | 12 ++--- crates/rec_aggregation/src/xmss_aggregate.rs | 52 ++++++++++++++++---- src/bench.sh | 27 ++++++++++ src/lib.rs | 22 ++++++++- src/main.rs | 28 +++++++++-- 8 files changed, 153 insertions(+), 45 deletions(-) create mode 100755 src/bench.sh diff --git a/README.md b/README.md index 6271a1df..bfe5ab03 100644 --- a/README.md +++ b/README.md @@ -20,30 +20,40 @@ The VM design is inspired by the famous [Cairo paper](https://eprint.iacr.org/20 ## Security -123 bits of security. Johnson bound + degree 5 extension of koala-bear -> **no proximity gaps conjecture**. (TODO 128 bits? this would require hash digests bigger than 8 koala-bears). +123 bits of provable security, given by Johnson bound + degree 5 extension of koala-bear. (128 bits would require hash digests of more than 8 field elements, todo?). In the benchmarks, we also display performance with conjectured security, even though leanVM targets the proven regime by default. -## Benchmarks (Slightly outdated, new benchmarks incoming) +## Benchmarks Machine: M4 Max 48GB (CPU only) -| Benchmark | Current | Target | -| -------------------------- | -------------------- | --------------- | -| Poseidon2 (16 koala-bears) | `560K Poseidon2 / s` | n/a | -| 2 -> 1 Recursion | `1.15 s` | `0.25 s ` | -| XMSS aggregation | `554 XMSS / s` | `1000 XMSS / s` | - *Expect incoming perf improvements.* -To reproduce: -- `cargo run --release -- poseidon --log-n-perms 20` -- `cargo run --release -- recursion --n 2` -- `cargo run --release -- xmss --n-signatures 1350` +### XMSS aggregation + +``` +cargo run --release -- xmss --n-signatures 1350 +``` + +| WHIR rate \ regime | Proven | Conjectured | +| ------------------ | -------------------- | -------------------- | +| 1/2 | 530 XMSS/s - 383 KiB | 530 XMSS/s - 209 KiB | +| 1/4 | 420 XMSS/s - 252 KiB | 420 XMSS/s - 148 KiB | + +(Proving throughput - proof size) + +### Recursion + +``` +cargo run --release -- recursion --n 2 +``` -## Proof size +2 to 1 recursion (WHIR rate = 1/4): -WHIR intial rate = 1/4 -> proof size ≈ 225 KiB. (150 KiB with rate 1/16, and < 100 KiB is possible with poximity gaps conjecture + rate 1/16). +| Proven | Conjectured | +| --------------- | --------------- | +| 1.10s - 223 KiB | 1.05s - 134 KiB | -(TODO: remaining optimization = [2024/108](https://eprint.iacr.org/2024/108.pdf) section 3.1) +(Note about proof size: remaining optimization = [2024/108](https://eprint.iacr.org/2024/108.pdf) section 3.1) ## Credits diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index b235a92e..e0cb450b 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -18,15 +18,16 @@ use trace_gen::*; // so ≈ 123.92 bits of security against collisions pub const SECURITY_BITS: usize = 123; // TODO 128 bits security? (with Poseidon over 20 field elements or with a more subtle soundness analysis (cf. https://eprint.iacr.org/2021/188.pdf)) -// Provable security (no proximity gaps conjectures) -pub const SECURITY_REGIME: SecurityAssumption = SecurityAssumption::JohnsonBound; - pub const GRINDING_BITS: usize = 18; -pub fn default_whir_config(starting_log_inv_rate: usize) -> WhirConfigBuilder { +pub fn default_whir_config(starting_log_inv_rate: usize, prox_gaps_conjecture: bool) -> WhirConfigBuilder { WhirConfigBuilder { folding_factor: FoldingFactor::new(7, 5), - soundness_type: SECURITY_REGIME, + soundness_type: if prox_gaps_conjecture { + SecurityAssumption::CapacityBound // TODO update formula with State of Art Conjecture + } else { + SecurityAssumption::JohnsonBound + }, pow_bits: GRINDING_BITS, max_num_variables_to_send_coeffs: 9, rs_domain_initial_reduction_factor: 5, diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 0078dd27..bd2e25ef 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -154,7 +154,7 @@ fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[ &bytecode, (public_input, private_input), &vec![], - &default_whir_config(starting_log_inv_rate), + &default_whir_config(starting_log_inv_rate, false), false, ); let proof_time = time.elapsed(); @@ -162,7 +162,7 @@ fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[ &bytecode, public_input, proof.proof.clone(), - &default_whir_config(starting_log_inv_rate), + &default_whir_config(starting_log_inv_rate, false), ) .unwrap(); println!("{}", proof.exec_summary); @@ -183,7 +183,7 @@ fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[ &bytecode, public_input, fuzzed_proof, - &default_whir_config(starting_log_inv_rate), + &default_whir_config(starting_log_inv_rate, false), ); assert!(verify_result.is_err(), "Fuzzing failed at index {}", i); } diff --git a/crates/rec_aggregation/src/recursion.rs b/crates/rec_aggregation/src/recursion.rs index 5fe842ae..5e7a68bc 100644 --- a/crates/rec_aggregation/src/recursion.rs +++ b/crates/rec_aggregation/src/recursion.rs @@ -14,16 +14,14 @@ use multilinear_toolkit::prelude::symbolic::{ use multilinear_toolkit::prelude::*; use utils::{BYTECODE_TABLE_INDEX, Counter, MEMORY_TABLE_INDEX}; -const LOG_INV_RATE: usize = 2; - -pub fn run_recursion_benchmark(count: usize, tracing: bool) { +pub fn run_recursion_benchmark(count: usize, log_inv_rate: usize, prox_gaps_conjecture: bool, tracing: bool) { let filepath = Path::new(env!("CARGO_MANIFEST_DIR")) .join("recursion.py") .to_str() .unwrap() .to_string(); - let inner_whir_config = default_whir_config(LOG_INV_RATE); + let inner_whir_config = default_whir_config(log_inv_rate, prox_gaps_conjecture); let program_to_prove = r#" DIM = 5 POSEIDON_OF_ZERO = POSEIDON_OF_ZERO_PLACEHOLDER @@ -333,7 +331,7 @@ def main(): &recursion_bytecode, (&outer_public_input, &outer_private_input), &vec![], // TODO precompute poseidons - &default_whir_config(LOG_INV_RATE), + &default_whir_config(log_inv_rate, prox_gaps_conjecture), false, ); let proving_time = time.elapsed(); @@ -341,7 +339,7 @@ def main(): &recursion_bytecode, &outer_public_input, recursion_proof.proof, - &default_whir_config(LOG_INV_RATE), + &default_whir_config(log_inv_rate, prox_gaps_conjecture), ) .unwrap(); println!( @@ -589,5 +587,5 @@ fn display_all_air_evals_in_zk_dsl() { #[test] fn test_end2end_recursion() { - run_recursion_benchmark(1, false); + run_recursion_benchmark(1, 2, false, false); } diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 71076b3f..088f133f 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -13,7 +13,6 @@ use xmss::{ }; static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); -const LOG_INV_RATE: usize = 1; fn get_xmss_aggregation_program() -> &'static Bytecode { XMSS_AGGREGATION_PROGRAM.get_or_init(compile_xmss_aggregation_program) @@ -81,7 +80,7 @@ fn compile_xmss_aggregation_program() -> Bytecode { compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements }) } -pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { +pub fn run_xmss_benchmark(n_signatures: usize, log_inv_rate: usize, prox_gaps_conjecture: bool, tracing: bool) { if tracing { utils::init_tracing(); } @@ -116,11 +115,26 @@ pub fn run_xmss_benchmark(n_signatures: usize, tracing: bool) { all_signatures.push(sig.clone()); } let time = Instant::now(); - let (proof_data, n_field_elements_in_proof, summary) = - xmss_aggregate_signatures_helper(&xmss_pub_keys, &all_signatures, message, slot).unwrap(); + let (proof_data, n_field_elements_in_proof, summary) = xmss_aggregate_signatures_helper( + &xmss_pub_keys, + &all_signatures, + message, + slot, + log_inv_rate, + prox_gaps_conjecture, + ) + .unwrap(); let proving_time = time.elapsed(); - xmss_verify_aggregated_signatures(&xmss_pub_keys, message, &proof_data, slot).unwrap(); + xmss_verify_aggregated_signatures( + &xmss_pub_keys, + message, + &proof_data, + slot, + log_inv_rate, + prox_gaps_conjecture, + ) + .unwrap(); println!("{summary}"); println!( @@ -142,8 +156,18 @@ pub fn xmss_aggregate_signatures( all_signatures: &[XmssSignature], message: [F; MESSAGE_LEN_FE], slot: u32, + log_inv_rate: usize, + prox_gaps_conjecture: bool, ) -> Result, XmssAggregateError> { - Ok(xmss_aggregate_signatures_helper(xmss_pub_keys, all_signatures, message, slot)?.0) + Ok(xmss_aggregate_signatures_helper( + xmss_pub_keys, + all_signatures, + message, + slot, + log_inv_rate, + prox_gaps_conjecture, + )? + .0) } fn xmss_aggregate_signatures_helper( @@ -151,6 +175,8 @@ fn xmss_aggregate_signatures_helper( all_signatures: &[XmssSignature], message: [F; MESSAGE_LEN_FE], slot: u32, + log_inv_rate: usize, + prox_gaps_conjecture: bool, ) -> Result<(Vec, usize, String), XmssAggregateError> { if xmss_pub_keys.len() != all_signatures.len() { return Err(XmssAggregateError::WrongSignatureCount); @@ -168,7 +194,7 @@ fn xmss_aggregate_signatures_helper( program, (&public_input, &private_input), &poseidons_16_precomputed, - &default_whir_config(LOG_INV_RATE), + &default_whir_config(log_inv_rate, prox_gaps_conjecture), false, ); @@ -182,6 +208,8 @@ pub fn xmss_verify_aggregated_signatures( message: [F; MESSAGE_LEN_FE], proof_bytes: &[u8], slot: u32, + log_inv_rate: usize, + prox_gaps_conjecture: bool, ) -> Result<(), ProofError> { let program = get_xmss_aggregation_program(); @@ -191,7 +219,13 @@ pub fn xmss_verify_aggregated_signatures( let public_input = build_public_input(xmss_pub_keys, message, slot); - verify_execution(program, &public_input, proof, &default_whir_config(LOG_INV_RATE)).map(|_| ()) + verify_execution( + program, + &public_input, + proof, + &default_whir_config(log_inv_rate, prox_gaps_conjecture), + ) + .map(|_| ()) } #[instrument(skip_all)] @@ -212,5 +246,5 @@ fn precompute_poseidons( #[test] fn test_xmss_aggregate() { - run_xmss_benchmark(5, false); + run_xmss_benchmark(5, 1, false, false); } diff --git a/src/bench.sh b/src/bench.sh new file mode 100755 index 00000000..ff355300 --- /dev/null +++ b/src/bench.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +echo -e "\033[32m=== XMSS: 1350 sigs, log-inv-rate=1 ===\033[0m" +cargo run --release -- xmss --n-signatures 1350 --log-inv-rate 1 +sleep 1s + +echo -e "\033[32m=== XMSS: 1350 sigs, log-inv-rate=2 ===\033[0m" +cargo run --release -- xmss --n-signatures 1350 --log-inv-rate 2 +sleep 1s + +echo -e "\033[32m=== XMSS: 1350 sigs, log-inv-rate=1, prox-gaps-conjecture ===\033[0m" +cargo run --release -- xmss --n-signatures 1350 --log-inv-rate 1 --prox-gaps-conjecture +sleep 1s + +echo -e "\033[32m=== XMSS: 1350 sigs, log-inv-rate=2, prox-gaps-conjecture ===\033[0m" +cargo run --release -- xmss --n-signatures 1350 --log-inv-rate 2 --prox-gaps-conjecture +sleep 1s + +echo -e "\033[32m=== Recursion: n=2 ===\033[0m" +cargo run --release -- recursion --n 2 +sleep 1s + +echo -e "\033[32m=== Recursion: n=2, prox-gaps-conjecture ===\033[0m" +cargo run --release -- recursion --n 2 --prox-gaps-conjecture +sleep 1s + +echo -e "\033[32m=== Done ===\033[0m" diff --git a/src/lib.rs b/src/lib.rs index 91de27c1..ac0362d0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,7 +68,25 @@ mod tests { .collect(); let (xmss_pub_keys, all_signatures): (Vec<_>, Vec<_>) = pub_keys_and_sigs.into_iter().unzip(); - let proof = xmss_aggregate_signatures(&xmss_pub_keys, &all_signatures, message_hash, slot).unwrap(); - xmss_verify_aggregated_signatures(&xmss_pub_keys, message_hash, &proof, slot).unwrap(); + let log_inv_rate = 1; + let prox_gaps_conjecture = false; + let proof = xmss_aggregate_signatures( + &xmss_pub_keys, + &all_signatures, + message_hash, + slot, + log_inv_rate, + prox_gaps_conjecture, + ) + .unwrap(); + xmss_verify_aggregated_signatures( + &xmss_pub_keys, + message_hash, + &proof, + slot, + log_inv_rate, + prox_gaps_conjecture, + ) + .unwrap(); } } diff --git a/src/main.rs b/src/main.rs index 0dc01694..09ae81f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,11 @@ enum Cli { Xmss { #[arg(long)] n_signatures: usize, + #[arg(long, help = "log(1/rate) in WHIR", default_value = "1", short = 'r')] + log_inv_rate: usize, + // TODO use the latest results (i.e. update the conjecture) + #[arg(long, help = "Uses Conjecture 4.12 from WHIR (up to capacity)")] + prox_gaps_conjecture: bool, #[arg(long, help = "Enable tracing")] tracing: bool, }, @@ -17,6 +22,11 @@ enum Cli { Recursion { #[arg(long, default_value = "1", help = "Number of recursive proofs to aggregate")] n: usize, + #[arg(long, help = "log(1/rate) in WHIR", default_value = "2", short = 'r')] + log_inv_rate: usize, + // TODO use the latest results (i.e. update the conjecture) + #[arg(long, help = "Uses Conjecture 4.12 from WHIR (up to capacity)")] + prox_gaps_conjecture: bool, #[arg(long, help = "Enable tracing")] tracing: bool, }, @@ -33,11 +43,21 @@ fn main() { let cli = Cli::parse(); match cli { - Cli::Xmss { n_signatures, tracing } => { - run_xmss_benchmark(n_signatures, tracing); + Cli::Xmss { + n_signatures, + log_inv_rate, + prox_gaps_conjecture, + tracing, + } => { + run_xmss_benchmark(n_signatures, log_inv_rate, prox_gaps_conjecture, tracing); } - Cli::Recursion { n, tracing } => { - run_recursion_benchmark(n, tracing); + Cli::Recursion { + n, + log_inv_rate, + prox_gaps_conjecture, + tracing, + } => { + run_recursion_benchmark(n, log_inv_rate, prox_gaps_conjecture, tracing); } Cli::Poseidon { log_n_perms: log_count, From 5f919813cf5540a3ce6552a50e6d49915f87589e Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 6 Feb 2026 23:29:54 +0400 Subject: [PATCH 26/29] w --- README.md | 2 -- TODO.md | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index bfe5ab03..68cc6cd3 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,6 @@ cargo run --release -- recursion --n 2 | --------------- | --------------- | | 1.10s - 223 KiB | 1.05s - 134 KiB | -(Note about proof size: remaining optimization = [2024/108](https://eprint.iacr.org/2024/108.pdf) section 3.1) - ## Credits - [Plonky3](https://github.com/Plonky3/Plonky3) for its various performant crates diff --git a/TODO.md b/TODO.md index b64b7fc7..08169f62 100644 --- a/TODO.md +++ b/TODO.md @@ -14,6 +14,7 @@ - XMSS: move from toy implem (usefull for benchmark) to a secure implem - Recursion: Remove the few hardcoded constants that depend on the guest execution (cycles etc) - About the ordering of the variables in sumchecks, currently we do as follows: +- [2024/108](https://eprint.iacr.org/2024/108.pdf) section 3.1 [a, b, c, d, e, f, g, h] (1st round of sumcheck) [(a-r).a + r.e, (1-r).b + r.f, (1-r).c + r.g, (1-r).d + r.h] (2nd round of sumcheck) From 5afb964cc6b6780a17cffba70bc8fb0142e526d4 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 7 Feb 2026 00:29:45 +0400 Subject: [PATCH 27/29] wip --- TODO.md | 1 - crates/xmss/params.md | 72 ++++++++++++++++++++++++++++++++++++++++++ crates/xmss/src/lib.rs | 5 --- src/lib.rs | 4 --- 4 files changed, 72 insertions(+), 10 deletions(-) create mode 100644 crates/xmss/params.md diff --git a/TODO.md b/TODO.md index 08169f62..e5633253 100644 --- a/TODO.md +++ b/TODO.md @@ -11,7 +11,6 @@ - Avoid the embedding overhead in logup, when denominators = "c - index" - Proof size: replace all equality checks in the verifier algo by value deduction - Poseidon in 'Compression' mode everywhere (except in 'Sponge' mode? cf. eprint 2014/223) -- XMSS: move from toy implem (usefull for benchmark) to a secure implem - Recursion: Remove the few hardcoded constants that depend on the guest execution (cycles etc) - About the ordering of the variables in sumchecks, currently we do as follows: - [2024/108](https://eprint.iacr.org/2024/108.pdf) section 3.1 diff --git a/crates/xmss/params.md b/crates/xmss/params.md new file mode 100644 index 00000000..71b39aee --- /dev/null +++ b/crates/xmss/params.md @@ -0,0 +1,72 @@ +# XMSS parameters (WIP) + +> **Warning:** The current implementation does not match the [leanSig](https://github.com/leanEthereum/leanSig) paper and does not provide 128-bit security in the Standard Model (though it may still be secure in the ROM/QROM). Expect changes in the future. + +## 1. Field and Hash + +**Field:** KoalaBear, p = 2^31 - 2^24 + 1. Each field element fits in a u32. + +**Hash:** Poseidon2 (width 16) in compression mode: `compress: [F; 16] -> [F; 8]`. Applies the Poseidon2 permutation, adds the input (feed-forward), and returns the first 8 elements. + +**Digest:** 8 field elements (~248 bits). Used for tree nodes, and chain values. + +**Chain step:** `chain_step(x) = compress(x, 0)`. Iterated n times: `iterate_hash(x, n) = chain_step^n(x)`. + +## 2. WOTS + +| Parameter | Symbol | Value | +|---|---|---| +| Chains | V | 40 | +| Winternitz parameter | W | 3 | +| Chain length | CHAIN_LENGTH | 2^W = 8 | +| Verifier chain hashes | NUM_CHAIN_HASHES | 120 | +| Signer chain hashes | TARGET_SUM | 160 (= V*(CHAIN_LENGTH-1) - NUM_CHAIN_HASHES) | +| Grinding chains | V_GRINDING | 3 | +| Message length | MESSAGE_LEN_FE | 9 | +| Randomness length | RANDOMNESS_LEN_FE | 7 | +| Truncated root length | TRUNCATED_MERKLE_ROOT_LEN_FE | 6 | + +### 2.1 Encoding + +Converts (message, randomness, slot, truncated_merkle_root) into 40 chain indices via a **fixed-sum encoding** (indices sum to TARGET_SUM, eliminating the need for checksum chains). + +1. `A = compress(message[0..8], [message[8], randomness[0..7]])` +2. `B = compress(A, [slot_lo, slot_hi, merkle_root[0..6]])` where slot is split into two 16-bit field elements. +3. Reject if any element of B equals -1 (uniformity guard). +4. Extract 24 bits per element of B (little-endian), split into 3-bit chunks, take first 43. +5. Valid iff: first 40 sum to 160, last 3 all equal 7. Otherwise retry with new randomness. + +(Note: adding part of the merkle root to the encoding computation contributes to multi-user security via domain-separation, otherwise the security of the encoding W * (V + V_GRINDING) would degrade bellow 128 bits with multiple users.) + +### 2.2 Keys + +- **Secret key:** 40 random pre-image digests. +- **Public key:** `pk[i] = iterate_hash(pre_image[i], 7)` for each chain. +- **Public key hash:** sequential left fold: `compress(compress(...compress(pk[0], pk[1])..., pk[38]), pk[39])` (39 compressions). + +### 2.3 Sign and Verify + +**Sign:** Find randomness r yielding a valid encoding, then `chain_tip[i] = iterate_hash(pre_image[i], encoding[i])`. Signature = (chain_tips, r). + +**Verify (public key recovery):** Recompute encoding from (message, slot, truncated_root, r), then `recovered_pk[i] = iterate_hash(chain_tip[i], 7 - encoding[i])`. + +## 3. XMSS + +**Tree:** Binary Merkle tree of depth LOG_LIFETIME = 32 (2^32 slots). Nodes = `compress(left, right)`. + +### 3.1 Key Generation + +Inputs: seed (32 bytes), slot range [start, end]. Only WOTS leaves for [start, end] are generated; Merkle nodes outside this range are filled with deterministic random digests (derived from the seed). To an observer, the resulting tree is indistinguishable from a full 2^32-leaf tree. + +**Public key:** the Merkle root (single digest). + + +... +TODO + +## 4. Properties + +- public key size: 31 bytes +- num. hashes at signing: < 2^16 (mostly grinding at encoding) +- num. hashes at verification: 2 (encoding) + NUM_CHAIN_HASHES + V + LOG_LIFETIME = 194 +- sig. size : RANDOMNESS_LEN_FE + 8 * (V + LOG_LIFETIME) = 583 field elements = 2.21 KiB \ No newline at end of file diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 964eae6d..eaa918a6 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -1,8 +1,3 @@ -/* -Toy (unsecure) XMSS, intended for benchmark only. -Production-grade XMSS SOON. -*/ - #![cfg_attr(not(test), warn(unused_crate_dependencies))] use p3_koala_bear::KoalaBear; mod wots; diff --git a/src/lib.rs b/src/lib.rs index ac0362d0..bf72038b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,10 +17,6 @@ pub fn xmss_aggregation_setup_verifier() { pub type F = p3_koala_bear::KoalaBear; -/* -WARNING: Toy XMSS, do not consider this secure (for now)! -*/ - #[cfg(test)] mod tests { use super::*; From 822b02e22162313b3dde8f6cddadaa37602213ad Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 7 Feb 2026 00:32:30 +0400 Subject: [PATCH 28/29] fix #132 --- crates/xmss/src/wots.rs | 8 ++++---- crates/xmss/src/xmss.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 003a1d6e..6e309087 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -1,5 +1,5 @@ use multilinear_toolkit::prelude::*; -use rand::{Rng, RngCore}; +use rand::{CryptoRng, Rng}; use utils::{ToUsize, to_little_endian_bits}; use crate::*; @@ -20,7 +20,7 @@ pub struct WotsSignature { } impl WotsSecretKey { - pub fn random(rng: &mut impl RngCore) -> Self { + pub fn random(rng: &mut impl CryptoRng) -> Self { Self::new(rng.random()) } @@ -40,7 +40,7 @@ impl WotsSecretKey { message: &[F; MESSAGE_LEN_FE], slot: u32, truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], - rng: &mut impl Rng, + rng: &mut impl CryptoRng, ) -> WotsSignature { let (randomness, encoding, _) = find_randomness_for_wots_encoding(message, slot, truncated_merkle_root, rng); WotsSignature { @@ -119,7 +119,7 @@ pub fn find_randomness_for_wots_encoding( message: &[F; MESSAGE_LEN_FE], slot: u32, truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], - rng: &mut impl Rng, + rng: &mut impl CryptoRng, ) -> ([F; RANDOMNESS_LEN_FE], [u8; V], usize) { let mut num_iters = 0; loop { diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index a42c6c1b..e94d5ed9 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -1,5 +1,5 @@ use multilinear_toolkit::prelude::*; -use rand::{Rng, SeedableRng, rngs::StdRng}; +use rand::{CryptoRng, Rng, SeedableRng, rngs::StdRng}; use sha3::{Digest as Sha3Digest, Keccak256}; use crate::*; @@ -112,7 +112,7 @@ pub enum XmssSignatureError { SlotOutOfRange, } -pub fn xmss_sign( +pub fn xmss_sign( rng: &mut R, secret_key: &XmssSecretKey, message_hash: &[F; MESSAGE_LEN_FE], From b1547542d41089f1a45739acc7f4f98f5de1aaf6 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 7 Feb 2026 00:33:33 +0400 Subject: [PATCH 29/29] clippy --- crates/rec_aggregation/src/xmss_aggregate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 088f133f..a7f98dbe 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -88,7 +88,7 @@ pub fn run_xmss_benchmark(n_signatures: usize, log_inv_rate: usize, prox_gaps_co precompute_dft_twiddles::(1 << 24); let message = (0..MESSAGE_LEN_FE) - .map(|i| F::from_usize(i)) + .map(F::from_usize) .collect::>() .try_into() .unwrap();