diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 2603e7c123..f9036b30e9 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -71,7 +71,7 @@ pub trait FrameworkEval { pub struct FrameworkComponent { eval: C, - trace_locations: TreeVec, + pub trace_locations: TreeVec, } impl FrameworkComponent { diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index b9c7b0866c..1aeef77308 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -140,7 +140,7 @@ impl LookupElements { } pub fn combine(&self, values: &[F]) -> EF where - EF: Copy + Zero + From + From + Mul + Sub, + EF: Copy + Zero + From + From + Mul + Sub, { EF::from(values[0]) + values[1..] diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index ae9ab6b654..cd3f5937e7 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -326,6 +326,7 @@ mod tests { let GkrArtifact { ood_point: r, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; @@ -354,6 +355,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -391,6 +393,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -427,6 +430,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 017948deed..7ee7b268ee 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -559,6 +559,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; @@ -590,6 +591,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -629,6 +631,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -666,6 +669,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 6e6ed2586a..d3d792ccb0 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use num_traits::{One, Zero}; use thiserror::Error; -use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask}; +use super::gkr_verifier::{Gate, GkrArtifact, GkrBatchProof, GkrMask}; use super::mle::{Mle, MleOps}; use super::sumcheck::MultivariatePolyOracle; use super::utils::{eq, random_linear_combination, UnivariatePoly}; @@ -409,6 +409,16 @@ pub fn prove_batch( .collect_vec(); let n_layers = *n_layers_by_instance.iter().max().unwrap(); + let gate_by_instance = input_layer_by_instance + .iter() + .map(|l| match l { + Layer::GrandProduct(_) => Gate::GrandProduct, + Layer::LogUpGeneric { .. } + | Layer::LogUpMultiplicities { .. } + | Layer::LogUpSingles { .. } => Gate::LogUp, + }) + .collect(); + // Evaluate all instance circuits and collect the layer values. let mut layers_by_instance = input_layer_by_instance .into_iter() @@ -502,6 +512,7 @@ pub fn prove_batch( let artifact = GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance: n_layers_by_instance, }; diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162e..322ba1034e 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -143,6 +143,7 @@ pub fn partially_verify_batch( Ok(GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(), }) @@ -162,12 +163,114 @@ pub struct GkrBatchProof { pub struct GkrArtifact { /// Out-of-domain (OOD) point for evaluating columns in the input layer. pub ood_point: Vec, + /// The gate of each instance. + pub gate_by_instance: Vec, /// The claimed evaluation at `ood_point` for each column in the input layer of each instance. pub claims_to_verify_by_instance: Vec>, /// The number of variables that interpolate the input layer of each instance. pub n_variables_by_instance: Vec, } +impl GkrArtifact { + pub fn ood_point(&self, instance_n_variables: usize) -> &[SecureField] { + &self.ood_point[self.ood_point.len() - instance_n_variables..] + } +} + +pub struct LookupArtifactInstanceIter<'proof, 'artifact> { + instance: usize, + gkr_proof: &'proof GkrBatchProof, + gkr_artifact: &'artifact GkrArtifact, +} + +impl<'proof, 'artifact> LookupArtifactInstanceIter<'proof, 'artifact> { + pub fn new(gkr_proof: &'proof GkrBatchProof, gkr_artifact: &'artifact GkrArtifact) -> Self { + Self { + instance: 0, + gkr_proof, + gkr_artifact, + } + } +} + +impl<'proof, 'artifact> Iterator for LookupArtifactInstanceIter<'proof, 'artifact> { + type Item = LookupArtifactInstance; + + fn next(&mut self) -> Option { + if self.instance >= self.gkr_proof.output_claims_by_instance.len() { + return None; + } + + let instance = self.instance; + let input_n_variables = self.gkr_artifact.n_variables_by_instance[instance]; + let eval_point = self.gkr_artifact.ood_point(input_n_variables).to_vec(); + let output_claim = &*self.gkr_proof.output_claims_by_instance[instance]; + let input_claims = &*self.gkr_artifact.claims_to_verify_by_instance[instance]; + let gate = self.gkr_artifact.gate_by_instance[instance]; + + let res = Some(match gate { + Gate::LogUp => { + let [numerator, denominator] = output_claim.try_into().unwrap(); + let claimed_sum = Fraction::new(numerator, denominator); + let [input_numerators_claim, input_denominators_claim] = + input_claims.try_into().unwrap(); + + LookupArtifactInstance::LogUp(LogUpArtifactInstance { + eval_point, + input_n_variables, + input_numerators_claim, + input_denominators_claim, + claimed_sum, + }) + } + Gate::GrandProduct => { + let [claimed_product] = output_claim.try_into().unwrap(); + let [input_claim] = input_claims.try_into().unwrap(); + + LookupArtifactInstance::GrandProduct(GrandProductArtifactInstance { + eval_point, + input_n_variables, + input_claim, + claimed_product, + }) + } + }); + + self.instance += 1; + res + } +} + +// TODO: Consider making the GKR artifact just a Vec. +pub enum LookupArtifactInstance { + GrandProduct(GrandProductArtifactInstance), + LogUp(LogUpArtifactInstance), +} + +pub struct GrandProductArtifactInstance { + /// GKR input layer eval point. + pub eval_point: Vec, + /// Number of variables the MLE in the GKR input layer had. + pub input_n_variables: usize, + /// Claimed input MLE evaluation at `eval_point`. + pub input_claim: SecureField, + /// Output claim from the circuit. + pub claimed_product: SecureField, +} + +pub struct LogUpArtifactInstance { + /// GKR input layer eval point. + pub eval_point: Vec, + /// Number of variables the MLEs in the GKR input layer had. + pub input_n_variables: usize, + /// Claimed input numerators MLE evaluation at `eval_point`. + pub input_numerators_claim: SecureField, + /// Claimed input denominators MLE evaluation at `eval_point`. + pub input_denominators_claim: SecureField, + /// Output claim from the circuit. + pub claimed_sum: Fraction, +} + /// Defines how a circuit operates locally on two input rows to produce a single output row. /// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure. /// @@ -176,7 +279,7 @@ pub struct GkrArtifact { /// circuit) GKR prover implementations. /// /// [Thaler13]: https://eprint.iacr.org/2013/351.pdf -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Gate { LogUp, GrandProduct, @@ -305,11 +408,13 @@ mod tests { let GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance, } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]); + assert_eq!(gate_by_instance, [Gate::LogUp, Gate::LogUp]); assert_eq!(proof.output_claims_by_instance.len(), 2); assert_eq!(claims_to_verify_by_instance.len(), 2); assert_eq!(proof.output_claims_by_instance[0], &[product0]); @@ -338,11 +443,13 @@ mod tests { let GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance, } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]); + assert_eq!(gate_by_instance, [Gate::LogUp, Gate::LogUp]); assert_eq!(proof.output_claims_by_instance.len(), 2); assert_eq!(claims_to_verify_by_instance.len(), 2); assert_eq!(proof.output_claims_by_instance[0], &[product0]); diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 58bb8633b9..651f44753a 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -61,9 +61,9 @@ impl BlakeStatement0 { } pub struct AllElements { - blake_elements: BlakeElements, - round_elements: RoundElements, - xor_elements: BlakeXorElements, + pub blake_elements: BlakeElements, + pub round_elements: RoundElements, + pub xor_elements: BlakeXorElements, } impl AllElements { pub fn draw(channel: &mut impl Channel) -> Self { @@ -223,7 +223,7 @@ where { assert!(log_size >= LOG_N_LANES); assert_eq!( - ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::() as usize, + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), N_ROUNDS ); @@ -240,7 +240,7 @@ where span.exit(); // Prepare inputs. - let blake_inputs = (0..(1 << (log_size - LOG_N_LANES))) + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) .map(|i| { let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16]; let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; @@ -282,18 +282,15 @@ where // Trace commitment. let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - scheduler_trace, - round_traces.into_iter().flatten(), - xor_trace12, - xor_trace9, - xor_trace8, - xor_trace7, - xor_trace4, - ] - .collect_vec(), - ); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); tree_builder.commit(channel); span.exit(); diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index 6fbe6d81be..06a2eb6a0e 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -12,28 +12,28 @@ use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::FieldExpOps; -mod air; -mod round; -mod scheduler; -mod xor_table; +pub mod air; +pub mod round; +pub mod scheduler; +pub mod xor_table; -const STATE_SIZE: usize = 16; -const MESSAGE_SIZE: usize = 16; -const N_FELTS_IN_U32: usize = 2; -const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; +pub const STATE_SIZE: usize = 16; +pub const MESSAGE_SIZE: usize = 16; +pub const N_FELTS_IN_U32: usize = 2; +pub const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; // Parameters for Blake2s. Change these for blake3. -const N_ROUNDS: usize = 10; +pub const N_ROUNDS: usize = 10; /// A splitting N_ROUNDS into several powers of 2. -const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; +pub const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; #[derive(Default)] -struct XorAccums { - xor12: XorAccumulator<12, 4>, - xor9: XorAccumulator<9, 2>, - xor8: XorAccumulator<8, 2>, - xor7: XorAccumulator<7, 2>, - xor4: XorAccumulator<4, 0>, +pub struct XorAccums { + pub xor12: XorAccumulator<12, 4>, + pub xor9: XorAccumulator<9, 2>, + pub xor8: XorAccumulator<8, 2>, + pub xor7: XorAccumulator<7, 2>, + pub xor4: XorAccumulator<4, 0>, } impl XorAccums { fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) { @@ -50,11 +50,11 @@ impl XorAccums { #[derive(Clone)] pub struct BlakeXorElements { - xor12: XorElements, - xor9: XorElements, - xor8: XorElements, - xor7: XorElements, - xor4: XorElements, + pub xor12: XorElements, + pub xor9: XorElements, + pub xor8: XorElements, + pub xor7: XorElements, + pub xor4: XorElements, } impl BlakeXorElements { fn draw(channel: &mut impl Channel) -> Self { @@ -75,7 +75,7 @@ impl BlakeXorElements { xor4: XorElements::dummy(), } } - fn get(&self, w: u32) -> &XorElements { + pub fn get(&self, w: u32) -> &XorElements { match w { 12 => &self.xor12, 9 => &self.xor9, @@ -89,7 +89,7 @@ impl BlakeXorElements { /// Utility for representing a u32 as two field elements, for constraint evaluation. #[derive(Clone, Copy, Debug)] -struct Fu32 +pub struct Fu32 where F: FieldExpOps + Copy @@ -112,7 +112,7 @@ where + Sub + Mul, { - fn to_felts(self) -> [F; 2] { + pub fn to_felts(self) -> [F; 2] { [self.l, self.h] } } diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index 9bddcdd40b..6075e49820 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -23,19 +23,19 @@ use crate::examples::blake::{to_felts, XorAccums, N_ROUND_INPUT_FELTS, STATE_SIZ pub struct BlakeRoundLookupData { /// A vector of (w, [a_col, b_col, c_col]) for each xor lookup. /// w is the xor width. c_col is the xor col of a_col and b_col. - xor_lookups: Vec<(u32, [BaseColumn; 3])>, + pub xor_lookups: Vec<(u32, [BaseColumn; 3])>, /// A column of round lookup values (v_in, v_out, m). - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], + pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], } pub struct TraceGenerator { - log_size: u32, - trace: Vec, - xor_lookups: Vec<(u32, [BaseColumn; 3])>, - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], + pub log_size: u32, + pub trace: Vec, + pub xor_lookups: Vec<(u32, [BaseColumn; 3])>, + pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], } impl TraceGenerator { - fn new(log_size: u32) -> Self { + pub fn new(log_size: u32) -> Self { assert!(log_size >= LOG_N_LANES); let trace = (0..blake_round_info().mask_offsets[0].len()) .map(|_| unsafe { Col::::uninitialized(1 << log_size) }) @@ -50,7 +50,7 @@ impl TraceGenerator { } } - fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { + pub fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { TraceGeneratorRow { gen: self, col_index: 0, @@ -61,7 +61,7 @@ impl TraceGenerator { } /// Trace generator for the constraints defined at [`super::constraints::BlakeRoundEval`] -struct TraceGeneratorRow<'a> { +pub struct TraceGeneratorRow<'a> { gen: &'a mut TraceGenerator, col_index: usize, vec_row: usize, @@ -79,7 +79,7 @@ impl<'a> TraceGeneratorRow<'a> { self.append_felt(val >> 16); } - fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { + pub fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { let input_v = v; v.iter().for_each(|s| { self.append_u32(*s); @@ -215,7 +215,7 @@ pub fn generate_trace( let _span = span!(Level::INFO, "Round Generation").entered(); let mut generator = TraceGenerator::new(log_size); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let mut row_gen = generator.gen_row(vec_row); let BlakeRoundInput { v, m } = inputs.get(vec_row).copied().unwrap_or_default(); row_gen.generate(v, m); diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index cf83113393..49d798f4b0 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -1,7 +1,10 @@ mod constraints; mod gen; -pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; +pub use gen::{ + generate_interaction_trace, generate_trace, BlakeRoundInput, BlakeRoundLookupData, + TraceGenerator, TraceGeneratorRow, +}; use num_traits::Zero; use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 63b3cf6961..33c044879f 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -57,7 +57,7 @@ pub fn eval_blake_scheduler_constraints( logup.finalize(eval); } -fn eval_next_u32(eval: &mut E) -> Fu32 { +pub fn eval_next_u32(eval: &mut E) -> Fu32 { let l = eval.next_trace_mask(); let h = eval.next_trace_mask(); Fu32 { l, h } diff --git a/crates/prover/src/examples/blake/scheduler/gen.rs b/crates/prover/src/examples/blake/scheduler/gen.rs index cd6a99b2ff..ae3569ed0e 100644 --- a/crates/prover/src/examples/blake/scheduler/gen.rs +++ b/crates/prover/src/examples/blake/scheduler/gen.rs @@ -58,7 +58,7 @@ pub fn gen_trace( .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) .collect_vec(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let mut col_index = 0; let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| { @@ -125,11 +125,11 @@ pub fn gen_interaction_trace( let mut logup_gen = LogupTraceGenerator::new(log_size); - for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() { + for [l0, l1] in lookup_data.round_lookups.array_chunks() { let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p0: PackedSecureField = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); let p1: PackedSecureField = @@ -145,7 +145,7 @@ pub fn gen_interaction_trace( // with the entire blake lookup. let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p_blake: PackedSecureField = blake_lookup_elements.combine( &lookup_data .blake_lookups diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index e8a8c32f3a..3035626c8f 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -2,7 +2,8 @@ mod constraints; mod gen; use constraints::eval_blake_scheduler_constraints; -pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; +pub use constraints::eval_next_u32; +pub use gen::{gen_interaction_trace, gen_trace, BlakeInput, BlakeSchedulerLookupData}; use num_traits::Zero; use super::round::RoundElements; diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs index 195a6ca464..46309e640f 100644 --- a/crates/prover/src/examples/blake/xor_table/gen.rs +++ b/crates/prover/src/examples/blake/xor_table/gen.rs @@ -74,7 +74,7 @@ pub fn generate_interaction_trace( // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { + for vec_row in 0..1 << (column_bits::() - LOG_N_LANES) { // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. // Extract al, blh from vec_row. let al = vec_row >> (limb_bits - LOG_N_LANES); diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index bd74ea040f..9f344b18db 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -17,7 +17,9 @@ use std::simd::u32x16; use itertools::Itertools; use num_traits::Zero; -pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; +pub use r#gen::{ + generate_constant_trace, generate_interaction_trace, generate_trace, XorTableLookupData, +}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; @@ -37,7 +39,7 @@ pub fn trace_sizes() -> TreeVec()) } -const fn limb_bits() -> u32 { +pub const fn limb_bits() -> u32 { ELEM_BITS - EXPAND_BITS } pub const fn column_bits() -> u32 { diff --git a/crates/prover/src/examples/blake_gkr/air.rs b/crates/prover/src/examples/blake_gkr/air.rs new file mode 100644 index 0000000000..bd79eb5a15 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/air.rs @@ -0,0 +1,301 @@ +use std::array; +use std::simd::u32x16; + +use itertools::{chain, multiunzip, Itertools}; +use tracing::{span, Level}; + +use super::gkr_lookups::MleCoeffColumnOracleAccumulator; +use super::scheduler::BlakeSchedulerComponent; +use super::xor_table::XorLookupArtifacts; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::BackendForChannel; +use crate::core::channel::{Channel, MerkleChannel}; +use crate::core::lookups::gkr_prover::prove_batch; +use crate::core::lookups::gkr_verifier::{ + GkrBatchProof, LookupArtifactInstance, LookupArtifactInstanceIter, +}; +use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; +use crate::core::poly::circle::{CanonicCoset, PolyOps}; +use crate::core::prover::StarkProof; +use crate::core::vcs::ops::MerkleHasher; +use crate::examples::blake::air::AllElements; +use crate::examples::blake::scheduler::{self as air_scheduler, BlakeInput}; +use crate::examples::blake::{ + round as air_round, xor_table as air_xor_table, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, +}; +use crate::examples::blake_gkr::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use crate::examples::blake_gkr::round::RoundLookupArtifact; +use crate::examples::blake_gkr::scheduler::SchedulerLookupArtifact; +use crate::examples::blake_gkr::{round, scheduler, xor_table}; + +pub struct BlakeClaim { + log_size: u32, +} + +impl BlakeClaim { + fn mix_into(&self, channel: &mut impl Channel) { + // TODO(spapini): Do this better. + channel.mix_nonce(self.log_size as u64); + } +} + +pub struct BlakeProof { + pub claim: BlakeClaim, + pub gkr_proof: GkrBatchProof, + pub stark_proof: StarkProof, +} + +pub struct BlakeLookupArtifacts { + scheduler: SchedulerLookupArtifact, + /// `|ROUND_LOG_SPLIT|` many round artifacts. + rounds: Vec, + xor: XorLookupArtifacts, +} + +impl BlakeLookupArtifacts { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + scheduler: SchedulerLookupArtifact::new_from_iter(&mut iter), + rounds: ROUND_LOG_SPLIT + .iter() + .map(|_| RoundLookupArtifact::new_from_iter(&mut iter)) + .collect(), + xor: XorLookupArtifacts::new_from_iter(&mut iter), + } + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { + scheduler, + rounds, + xor, + } = self; + scheduler.verify_succinct_mle_claims()?; + for round in rounds { + round.verify_succinct_mle_claims()?; + } + xor.verify_succinct_mle_claims()?; + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { + scheduler, + rounds, + xor, + } = self; + scheduler.accumulate_mle_eval_iop_claims(acc); + rounds + .iter() + .for_each(|round| round.accumulate_mle_eval_iop_claims(acc)); + xor.accumulate_mle_eval_iop_claims(acc); + } +} + +#[derive(Debug)] +pub struct InvalidClaimError; + +pub struct BlakeComponents { + scheduler_component: BlakeSchedulerComponent, +} + +impl BlakeComponents { + pub fn new(_claim: &BlakeClaim, _gkr_artifact: BlakeLookupArtifacts) -> Self { + todo!() + } + + pub fn accumulate_mle_coeff_col_oracles<'this: 'acc, 'acc>( + &'this self, + acc: &mut MleCoeffColumnOracleAccumulator<'acc>, + ) { + acc.accumulate(&self.scheduler_component); + } +} + +pub fn prove_blake(log_size: u32, config: PcsConfig) -> BlakeProof +where + SimdBackend: BackendForChannel, +{ + assert!(log_size >= LOG_N_LANES); + assert_eq!( + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), + N_ROUNDS + ); + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + const XOR_TABLE_MAX_LOG_SIZE: u32 = 16; + let max_log_size = + (log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(max_log_size + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Prepare inputs. + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) + .map(|i| { + let v = [u32x16::from_array(array::from_fn(|j| (i + 2 * j) as u32)); 16]; + let m = [u32x16::from_array(array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; + BlakeInput { v, m } + }) + .collect_vec(); + + // Setup protocol. + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + + let span = span!(Level::INFO, "Trace").entered(); + + // Scheduler. + let (scheduler_trace, scheduler_lookup_data, round_inputs) = + air_scheduler::gen_trace(log_size, &blake_inputs); + + // Rounds. + let mut xor_accums = XorAccums::default(); + let mut rest = &round_inputs[..]; + // Split round inputs to components, according to [ROUND_LOG_SPLIT]. + let (round_traces, round_lookup_datas): (Vec<_>, Vec<_>) = + multiunzip(ROUND_LOG_SPLIT.map(|l| { + let (cur_inputs, r) = rest.split_at(1 << (log_size - LOG_N_LANES + l)); + rest = r; + air_round::generate_trace(log_size + l, cur_inputs, &mut xor_accums) + })); + + // Xor tables. + let (xor_trace12, xor_lookup_data12) = air_xor_table::generate_trace(xor_accums.xor12); + let (xor_trace9, xor_lookup_data9) = air_xor_table::generate_trace(xor_accums.xor9); + let (xor_trace8, xor_lookup_data8) = air_xor_table::generate_trace(xor_accums.xor8); + let (xor_trace7, xor_lookup_data7) = air_xor_table::generate_trace(xor_accums.xor7); + let (xor_trace4, xor_lookup_data4) = air_xor_table::generate_trace(xor_accums.xor4); + + // Claim. + let claim = BlakeClaim { log_size }; + claim.mix_into(channel); + + // Trace commitment. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup element. + let all_elements = AllElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let mut lookup_input_layers = Vec::new(); + let mut mle_eval_at_point_collection = MleCollection::default(); + + lookup_input_layers.extend(scheduler::generate_lookup_instances( + log_size, + scheduler_lookup_data, + &all_elements.round_elements, + &all_elements.blake_elements, + &mut mle_eval_at_point_collection, + )); + + ROUND_LOG_SPLIT + .iter() + .zip(round_lookup_datas) + .for_each(|(l, lookup_data)| { + lookup_input_layers.extend(round::generate_lookup_instances( + log_size + l, + lookup_data, + &all_elements.xor_elements, + &all_elements.round_elements, + &mut mle_eval_at_point_collection, + )); + }); + + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data12, + &all_elements.xor_elements.xor12, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data9, + &all_elements.xor_elements.xor9, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data8, + &all_elements.xor_elements.xor8, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data7, + &all_elements.xor_elements.xor7, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data4, + &all_elements.xor_elements.xor4, + &mut mle_eval_at_point_collection, + )); + + let (gkr_proof, gkr_artifact) = prove_batch(channel, lookup_input_layers); + let mle_acc_coeff = channel.draw_felt(); + let _mles = mle_eval_at_point_collection.random_linear_combine_by_n_variables(mle_acc_coeff); + span.exit(); + + // TODO(andrew): Consider unifying new_from_iter, verify_succinct_mle_claims, + // accumulate_mle_eval_iop_claims. + let mut lookup_instances_iter = LookupArtifactInstanceIter::new(&gkr_proof, &gkr_artifact); + let blake_lookup_artifacts = BlakeLookupArtifacts::new_from_iter(&mut lookup_instances_iter); + assert!(lookup_instances_iter.next().is_none()); + blake_lookup_artifacts.verify_succinct_mle_claims().unwrap(); + let mut mle_eval_iop_acc = MleClaimAccumulator::new(mle_acc_coeff); + blake_lookup_artifacts.accumulate_mle_eval_iop_claims(&mut mle_eval_iop_acc); + let mut _mle_claim_by_n_variables = mle_eval_iop_acc.finalize(); + + #[cfg(test)] + for mle in &_mles { + let n_variables = mle.n_variables(); + let eval_point = gkr_artifact.ood_point(n_variables); + let claim = _mle_claim_by_n_variables[n_variables].take().unwrap(); + debug_assert_eq!(claim, mle.eval_at_point(eval_point)); + } + + // let mle_coeff_col_acc_by_n_variables = + + todo!() +} + +#[cfg(test)] +mod tests { + use std::env; + + use crate::core::pcs::PcsConfig; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::blake_gkr::air::prove_blake; + + // Note: this test is slow. Only run in release. + #[cfg_attr(not(feature = "slow-tests"), ignore)] + #[test_log::test] + fn test_simd_blake_gkr_prove() { + // Get from environment variable: + let log_n_instances = env::var("LOG_N_INSTANCES") + .unwrap_or_else(|_| "6".to_string()) + .parse::() + .unwrap(); + let config = PcsConfig::default(); + + // Prove. + let _proof = prove_blake::(log_n_instances, config); + + // Verify. + // verify_blake::(proof, config).unwrap(); + } +} diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs similarity index 84% rename from crates/prover/src/examples/xor/gkr_lookups/accumulation.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs index a63b62503a..6a65de3ea0 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs @@ -2,7 +2,7 @@ use std::iter::zip; use std::ops::{AddAssign, Mul}; use educe::Educe; -use num_traits::One; +use num_traits::{One, Zero}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::Backend; @@ -16,7 +16,7 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; /// Max number of variables for multilinear polynomials that get compiled into a univariate /// IOP for multilinear eval at point. -pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; +pub const MAX_MLE_N_VARIABLES: usize = (M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR) as usize; /// Collection of [`Mle`]s grouped by their number of variables. pub struct MleCollection { @@ -92,7 +92,7 @@ pub fn combine + Copy, F: Copy>( impl Default for MleCollection { fn default() -> Self { Self { - mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES as usize + 1], + mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES + 1], } } } @@ -135,6 +135,32 @@ impl DynMle { } } +/// Accumulates claims of multilinear polynomials, grouped by their number of variables. +// TODO(andrew): Consider group by eval point to make sure everything done correctly. +pub struct MleClaimAccumulator { + acc_coeff: SecureField, + acc_by_n_variables: Vec>, +} + +impl MleClaimAccumulator { + pub fn new(acc_coeff: SecureField) -> Self { + Self { + acc_coeff, + acc_by_n_variables: vec![None; MAX_MLE_N_VARIABLES + 1], + } + } + + pub fn accumulate(&mut self, n_variables: usize, evaluation: SecureField) { + let acc = self.acc_by_n_variables[n_variables].get_or_insert_with(SecureField::zero); + *acc = *acc * self.acc_coeff + evaluation; + } + + /// Returns a mapping of number of variables to claim accumulation. + pub fn finalize(self) -> Vec> { + self.acc_by_n_variables + } +} + #[cfg(test)] mod tests { use std::iter::repeat; @@ -146,7 +172,7 @@ mod tests { use crate::core::fields::qm31::SecureField; use crate::core::fields::Field; use crate::core::lookups::mle::{Mle, MleOps}; - use crate::examples::xor::gkr_lookups::accumulation::MleCollection; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MleCollection; #[test] fn random_linear_combine_by_n_variables() { diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs similarity index 99% rename from crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs index eb8dc163f2..db12bf76a1 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs @@ -1,7 +1,4 @@ //! Multilinear extension (MLE) eval at point constraints. -// TODO(andrew): Remove in downstream PR. -#![allow(dead_code)] - use std::iter::zip; use itertools::{chain, zip_eq, Itertools}; @@ -742,8 +739,8 @@ mod tests { use crate::core::prover::{prove, verify, VerificationError}; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; - use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::eval_step_selector_with_offset; #[test] fn mle_eval_prover_component() -> Result<(), VerificationError> { @@ -812,7 +809,6 @@ mod tests { const N_VARIABLES: usize = 8; const COEFFS_COL_TRACE: usize = 0; const MLE_EVAL_TRACE: usize = 1; - const CONST_TRACE: usize = 2; const LOG_EXPAND: u32 = 1; // Create the test MLE. let mut rng = SmallRng::seed_from_u64(0); @@ -1121,7 +1117,7 @@ mod tests { use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::ColumnVec; - use crate::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::MleCoeffColumnOracle; pub type MleCoeffColumnComponent = FrameworkComponent; diff --git a/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs new file mode 100644 index 0000000000..f736a93cac --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs @@ -0,0 +1,53 @@ +use mle_eval::MleCoeffColumnOracle; + +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::circle::CirclePoint; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; + +pub mod accumulation; +pub mod mle_eval; + +// TODO(andrew): Try come up with less verbose name. +pub trait AccumulatedMleCoeffColumnOracle { + fn accumulate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ); +} + +// TODO(andrew): Try come up with less verbose name. +pub struct MleCoeffColumnOracleAccumulator<'a> { + acc_coeff: SecureField, + oracles: Vec<&'a dyn AccumulatedMleCoeffColumnOracle>, +} + +impl<'a> MleCoeffColumnOracleAccumulator<'a> { + pub fn new(acc_coeff: SecureField) -> Self { + Self { + acc_coeff, + oracles: Vec::new(), + } + } + + pub fn accumulate<'b: 'a>(&mut self, oracle: &'b dyn AccumulatedMleCoeffColumnOracle) { + self.oracles.push(oracle) + } +} + +impl<'a> MleCoeffColumnOracle for MleCoeffColumnOracleAccumulator<'a> { + fn evaluate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField { + let mut acc = PointEvaluationAccumulator::new(self.acc_coeff); + for oracle in &self.oracles { + oracle.accumulate_at_point(point, mask, &mut acc); + } + acc.finalize() + } +} diff --git a/crates/prover/src/examples/blake_gkr/mod.rs b/crates/prover/src/examples/blake_gkr/mod.rs new file mode 100644 index 0000000000..a1f767c6b6 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/mod.rs @@ -0,0 +1,5 @@ +pub mod air; +pub mod gkr_lookups; +pub mod round; +pub mod scheduler; +pub mod xor_table; diff --git a/crates/prover/src/examples/blake_gkr/round.rs b/crates/prover/src/examples/blake_gkr/round.rs new file mode 100644 index 0000000000..98897c3db5 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/round.rs @@ -0,0 +1,110 @@ +use num_traits::One; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::examples::blake::round::{BlakeRoundLookupData, RoundElements, TraceGenerator}; +use crate::examples::blake::BlakeXorElements; + +pub struct RoundLookupArtifact { + pub round: LogUpArtifactInstance, + pub xors: Vec, +} + +impl RoundLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + let xors = (0..n_xor_lookups()) + .map(|_| match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }) + .collect(); + + let round = match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }; + + Self { round, xors } + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { round, xors } = self; + + for xor in xors { + acc.accumulate(xor.input_n_variables, xor.input_denominators_claim); + } + + acc.accumulate(round.input_n_variables, round.input_denominators_claim); + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { round, xors } = self; + + for xor_artifact in xors { + if !xor_artifact.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + } + + if !round.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + + Ok(()) + } +} + +/// Returns an ordered list of all XOR lookup types the round component uses. +fn n_xor_lookups() -> usize { + // Create a dummy trace to extract the structural xor lookup information. + let mut trace_generator = TraceGenerator::new(LOG_N_LANES); + let mut row = trace_generator.gen_row(0); + row.generate(Default::default(), Default::default()); + trace_generator.xor_lookups.len() +} + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeRoundLookupData, + xor_lookup_elements: &BlakeXorElements, + round_lookup_elements: &RoundElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate round interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for (w, l) in &lookup_data.xor_lookups { + let lookup_elements = xor_lookup_elements.get(*w); + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = lookup_elements.combine(&l.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }); + } + + // Blake round lookup. + let mut round_denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements + .combine(&lookup_data.round_lookup.each_ref().map(|l| l.data[vec_row])); + round_denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(round_denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { + denominators: round_denominators, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/scheduler.rs b/crates/prover/src/examples/blake_gkr/scheduler.rs new file mode 100644 index 0000000000..b1dd50dd38 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/scheduler.rs @@ -0,0 +1,209 @@ +use std::array; + +use itertools::{chain, Itertools}; +use num_traits::{One, Zero}; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use super::gkr_lookups::AccumulatedMleCoeffColumnOracle; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator}; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::blake2s::SIGMA; +use crate::core::backend::simd::column::{BaseColumn, SecureColumn}; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; +use crate::examples::blake::round::RoundElements; +use crate::examples::blake::scheduler::{eval_next_u32, BlakeElements, BlakeSchedulerLookupData}; +use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; + +pub type BlakeSchedulerComponent = FrameworkComponent; + +pub struct BlakeSchedulerEval { + pub log_size: u32, + pub blake_lookup_elements: BlakeElements, + pub round_lookup_elements: RoundElements, +} + +impl FrameworkEval for BlakeSchedulerEval { + fn log_size(&self) -> u32 { + self.log_size + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + } + + fn evaluate(&self, mut eval: E) -> E { + let _ = SchedulerEvals::new(&mut eval); + eval + } +} + +impl AccumulatedMleCoeffColumnOracle for BlakeSchedulerComponent { + fn accumulate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ) { + // Create dummy point evaluator just to extract the value we need from the mask + let mut _accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut _accumulator, + SecureField::one(), + ); + + let SchedulerEvals { messages, states } = SchedulerEvals::new(&mut eval); + + // Schedule. + for i in 0..N_ROUNDS { + let input_state = &states[i]; + let output_state = &states[i + 1]; + let round_messages = SIGMA[i].map(|j| messages[j as usize]); + // Use triplet in round lookup. + let lookup_values = &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + round_messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(); + let denoms_mle_coeff_col_eval = self.round_lookup_elements.combine(lookup_values); + acc.accumulate(denoms_mle_coeff_col_eval); + } + + let input_state = &states[0]; + let output_state = &states[N_ROUNDS]; + let lookup_values = &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(); + let denoms_mle_coeff_col_eval = self.blake_lookup_elements.combine(lookup_values); + acc.accumulate(denoms_mle_coeff_col_eval); + } +} + +struct SchedulerEvals { + messages: [Fu32; STATE_SIZE], + states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1], +} + +impl SchedulerEvals { + fn new(eval: &mut E) -> Self { + Self { + messages: array::from_fn(|_| eval_next_u32(eval)), + states: array::from_fn(|_| array::from_fn(|_| eval_next_u32(eval))), + } + } +} + +pub struct SchedulerLookupArtifact { + scheduler: LogUpArtifactInstance, + rounds: [LogUpArtifactInstance; N_ROUNDS], +} + +impl SchedulerLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + let rounds = array::from_fn(|_| match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }); + + let scheduler = match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }; + + Self { scheduler, rounds } + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { scheduler, rounds } = self; + + // TODO(andrew): Consider checking the n_variables is correct. + // if !self.scheduler.input_numerators_claim.is_one() { + if !scheduler.input_numerators_claim.is_zero() { + return Err(InvalidClaimError); + } + + for round in rounds { + if !round.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + } + + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { scheduler, rounds } = self; + + for round in rounds { + acc.accumulate(round.input_n_variables, round.input_denominators_claim); + } + + // TODO: Note `n_variables` is not verified. Probably fine since if the prover gives wrong + // info they'll be caught. Can panic though if the n_variables is too high. Consider + // checking the number of GKR layers in the verifier is less than + // LOG_CIRCLE_ORDER-LOG_BLOWUP-LOG_EXPAND. + acc.accumulate( + scheduler.input_n_variables, + scheduler.input_denominators_claim, + ); + } +} + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeSchedulerLookupData, + round_lookup_elements: &RoundElements, + blake_lookup_elements: &BlakeElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate scheduler interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for l0 in &lookup_data.round_lookups { + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }) + } + + // Blake hash lookup. + let blake_numers = Mle::::new(BaseColumn::zeros(size)); + let mut blake_denoms = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let blake_denom: PackedSecureField = blake_lookup_elements.combine( + &lookup_data + .blake_lookups + .each_ref() + .map(|l| l.data[vec_row]), + ); + blake_denoms.data[vec_row] = blake_denom; + } + collection_for_univariate_iop.push(blake_denoms.clone()); + round_lookup_layers.push(Layer::LogUpMultiplicities { + numerators: blake_numers, + denominators: blake_denoms, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/xor_table.rs b/crates/prover/src/examples/blake_gkr/xor_table.rs new file mode 100644 index 0000000000..83eb80fb0b --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/xor_table.rs @@ -0,0 +1,158 @@ +use std::array; +use std::simd::u32x16; + +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::examples::blake::xor_table::{column_bits, limb_bits, XorElements, XorTableLookupData}; + +pub struct XorLookupArtifacts { + xor12: XorLookupArtifact<12, 4>, + xor9: XorLookupArtifact<9, 2>, + xor8: XorLookupArtifact<8, 2>, + xor7: XorLookupArtifact<7, 2>, + xor4: XorLookupArtifact<4, 0>, +} + +impl XorLookupArtifacts { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + xor12: XorLookupArtifact::new_from_iter(&mut iter), + xor9: XorLookupArtifact::new_from_iter(&mut iter), + xor8: XorLookupArtifact::new_from_iter(&mut iter), + xor7: XorLookupArtifact::new_from_iter(&mut iter), + xor4: XorLookupArtifact::new_from_iter(&mut iter), + } + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + + xor12.verify_succinct_mle_claims()?; + xor9.verify_succinct_mle_claims()?; + xor8.verify_succinct_mle_claims()?; + xor7.verify_succinct_mle_claims()?; + xor4.verify_succinct_mle_claims()?; + + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + xor12.accumulate_mle_eval_iop_claims(acc); + xor9.accumulate_mle_eval_iop_claims(acc); + xor8.accumulate_mle_eval_iop_claims(acc); + xor7.accumulate_mle_eval_iop_claims(acc); + xor4.accumulate_mle_eval_iop_claims(acc); + } +} + +pub struct XorLookupArtifact { + /// `2^(2*EXPAND_BITS)` many LogUp instances. + artifacts: Vec, +} + +impl XorLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + artifacts: (0..1 << (2 * EXPAND_BITS)) + .map(|_| match iter.next() { + // TODO: check input MLEs have expected number of variables. + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }) + .collect(), + } + } + + fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + for (i, _artifact) in self.artifacts.iter().enumerate() { + // Extract ah, bh from column index. + let _ah = i as u32 >> EXPAND_BITS; + let _bh = i as u32 & ((1 << EXPAND_BITS) - 1); + + let _ = 5; + } + + Ok(()) + } + + fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { artifacts } = self; + for artifact in artifacts { + acc.accumulate(artifact.input_n_variables, artifact.input_numerators_claim); + } + } +} + +pub fn generate_lookup_instances( + lookup_data: XorTableLookupData, + lookup_elements: &XorElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Xor interaction trace").entered(); + let limb_bits = limb_bits::(); + let col_bits = column_bits::(); + let col_size = 1 << col_bits; + let offsets_vec = u32x16::from_array(array::from_fn(|i| i as u32)); + let mut xor_lookup_layers = Vec::new(); + + // There are 2^(2*EXPAND_BITS) columns, for each combination of ah, bh. + for (i, mults) in lookup_data.xor_accum.mults.iter().enumerate() { + let numerators = Mle::::new(mults.clone()); + let mut denominators = Mle::::new(SecureColumn::zeros(col_size)); + + // Extract ah, bh from column index. + let ah = i as u32 >> EXPAND_BITS; + let bh = i as u32 & ((1 << EXPAND_BITS) - 1); + + // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. + #[allow(clippy::needless_range_loop)] + for vec_row in 0..1 << (col_bits - LOG_N_LANES) { + // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. + // Extract al, blh from vec_row. + let al = vec_row >> (limb_bits - LOG_N_LANES); + let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); + + // Construct the 3 vectors a, b, c. + let a = u32x16::splat((ah << limb_bits) | al); + // bll is just the consecutive numbers 0..N_LANES-1. + let b = u32x16::splat((bh << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; + let c = a ^ b; + + let denom = lookup_elements + .combine(&[a, b, c].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) })); + denominators.data[vec_row as usize] = denom; + } + + collection_for_univariate_iop.push(numerators.clone()); + xor_lookup_layers.push(Layer::LogUpMultiplicities { + numerators, + denominators, + }); + } + + xor_lookup_layers +} diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index c5e3a4eda3..d4beaabe27 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,4 +1,4 @@ pub mod blake; +pub mod blake_gkr; pub mod plonk; pub mod poseidon; -pub mod xor; diff --git a/crates/prover/src/examples/xor/gkr_lookups/mod.rs b/crates/prover/src/examples/xor/gkr_lookups/mod.rs deleted file mode 100644 index 6ee603eb03..0000000000 --- a/crates/prover/src/examples/xor/gkr_lookups/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod accumulation; -pub mod mle_eval; diff --git a/crates/prover/src/examples/xor/mod.rs b/crates/prover/src/examples/xor/mod.rs deleted file mode 100644 index 34e702a9b1..0000000000 --- a/crates/prover/src/examples/xor/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod gkr_lookups;