diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 4c7f75768..3066b8e66 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -5,7 +5,7 @@ use std::ops::Deref; use itertools::Itertools; use tracing::{span, Level}; -use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; +use super::{EvalAtRow, EvalAtRowWithMle, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; @@ -57,6 +57,18 @@ impl TraceLocationAllocator { } } +/// A component defined solely in means of the constraints framework. +/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for +/// the SIMD backend. +/// Note that the constraint framework only support components with columns of the same size. +pub trait FrameworkEvalWithMle { + fn log_size(&self) -> u32; + + fn max_constraint_log_degree_bound(&self) -> u32; + + fn evaluate(&self, eval: E) -> E; +} + /// A component defined solely in means of the constraints framework. /// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for /// the SIMD backend. diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index a608d89b0..ef30033b6 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -107,7 +107,7 @@ impl LookupElements { } pub fn combine(&self, values: &[F]) -> EF where - EF: Copy + Zero + From + From + Mul + Sub, + EF: Copy + Zero + From + From + Mul + Sub, { zip_eq(values, self.alpha_powers).fold(EF::zero(), |acc, (&value, power)| { acc + EF::from(power) * value diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 87069d344..db7a1d3e8 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -95,3 +95,7 @@ pub trait EvalAtRow { /// Combines 4 base field values into a single extension field value. fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; } + +trait EvalAtRowWithMle: EvalAtRow { + fn add_mle_coeff_col_eval(&mut self, eval: Self::EF); +} diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 6c6f72f81..a91b8ff7b 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -55,3 +55,56 @@ impl<'a> EvalAtRow for PointEvaluator<'a> { SecureField::from_partial_evals(values) } } + +// /// Evaluates expressions at a point out of domain. +// pub struct MleCoeffColEvalAccumulator<'a> { +// pub mask: TreeVec>>, +// pub evaluation_accumulator: &'a mut PointEvaluationAccumulator, +// pub col_index: Vec, +// pub denom_inverse: SecureField, +// } +// impl<'a> MleCoeffColEvalAccumulator<'a> { +// pub fn new( +// mask: TreeVec>>, +// evaluation_accumulator: &'a mut PointEvaluationAccumulator, +// denom_inverse: SecureField, +// ) -> Self { +// let col_index = vec![0; mask.len()]; +// Self { +// mask, +// evaluation_accumulator, +// col_index, +// denom_inverse, +// } +// } +// } +// impl<'a> EvalAtRow for MleCoeffColEvalAccumulator<'a> { +// type F = SecureField; +// type EF = SecureField; + +// fn next_interaction_mask( +// &mut self, +// interaction: usize, +// _offsets: [isize; N], +// ) -> [Self::F; N] { +// let col_index = self.col_index[interaction]; +// self.col_index[interaction] += 1; +// let mask = self.mask[interaction][col_index].clone(); +// assert_eq!(mask.len(), N); +// mask.try_into().unwrap() +// } +// fn add_constraint(&mut self, constraint: G) +// where +// Self::EF: Mul, +// { +// } +// fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { +// SecureField::from_partial_evals(values) +// } +// } + +// impl<'a> EvalAtRowWithMle for MleCoeffColEvalAccumulator<'a> { +// fn add_mle_coeff_col_eval(&mut self, eval: Self::EF) { +// todo!() +// } +// } diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index ae9ab6b65..cd3f5937e 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -326,6 +326,7 @@ mod tests { let GkrArtifact { ood_point: r, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; @@ -354,6 +355,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -391,6 +393,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -427,6 +430,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 017948dee..7ee7b268e 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -559,6 +559,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; @@ -590,6 +591,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -629,6 +631,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; @@ -666,6 +669,7 @@ mod tests { let GkrArtifact { ood_point, claims_to_verify_by_instance, + gate_by_instance: _, n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 6e6ed2586..d3d792ccb 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use num_traits::{One, Zero}; use thiserror::Error; -use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask}; +use super::gkr_verifier::{Gate, GkrArtifact, GkrBatchProof, GkrMask}; use super::mle::{Mle, MleOps}; use super::sumcheck::MultivariatePolyOracle; use super::utils::{eq, random_linear_combination, UnivariatePoly}; @@ -409,6 +409,16 @@ pub fn prove_batch( .collect_vec(); let n_layers = *n_layers_by_instance.iter().max().unwrap(); + let gate_by_instance = input_layer_by_instance + .iter() + .map(|l| match l { + Layer::GrandProduct(_) => Gate::GrandProduct, + Layer::LogUpGeneric { .. } + | Layer::LogUpMultiplicities { .. } + | Layer::LogUpSingles { .. } => Gate::LogUp, + }) + .collect(); + // Evaluate all instance circuits and collect the layer values. let mut layers_by_instance = input_layer_by_instance .into_iter() @@ -502,6 +512,7 @@ pub fn prove_batch( let artifact = GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance: n_layers_by_instance, }; diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index b65ceb162..4d17c15dc 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -143,6 +143,7 @@ pub fn partially_verify_batch( Ok(GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(), }) @@ -162,12 +163,114 @@ pub struct GkrBatchProof { pub struct GkrArtifact { /// Out-of-domain (OOD) point for evaluating columns in the input layer. pub ood_point: Vec, + /// The gate of each instance. + pub gate_by_instance: Vec, /// The claimed evaluation at `ood_point` for each column in the input layer of each instance. pub claims_to_verify_by_instance: Vec>, /// The number of variables that interpolate the input layer of each instance. pub n_variables_by_instance: Vec, } +impl GkrArtifact { + pub fn ood_point(&self, instance_n_variables: usize) -> &[SecureField] { + &self.ood_point[self.ood_point.len() - instance_n_variables..] + } +} + +pub struct LookupArtifactInstanceIter<'proof, 'artifact> { + instance: usize, + gkr_proof: &'proof GkrBatchProof, + gkr_artifact: &'artifact GkrArtifact, +} + +impl<'proof, 'artifact> LookupArtifactInstanceIter<'proof, 'artifact> { + pub fn new(gkr_proof: &'proof GkrBatchProof, gkr_artifact: &'artifact GkrArtifact) -> Self { + Self { + instance: 0, + gkr_proof, + gkr_artifact, + } + } +} + +impl<'proof, 'artifact> Iterator for LookupArtifactInstanceIter<'proof, 'artifact> { + type Item = LookupArtifactInstance; + + fn next(&mut self) -> Option { + if self.instance >= self.gkr_proof.output_claims_by_instance.len() { + return None; + } + + let instance = self.instance; + let input_n_variables = self.gkr_artifact.n_variables_by_instance[instance]; + let eval_point = self.gkr_artifact.ood_point(input_n_variables).to_vec(); + let output_claim = &*self.gkr_proof.output_claims_by_instance[instance]; + let input_claims = &*self.gkr_artifact.claims_to_verify_by_instance[instance]; + let gate = self.gkr_artifact.gate_by_instance[instance]; + + let res = Some(match gate { + Gate::LogUp => { + let [numerator, denominator] = output_claim.try_into().unwrap(); + let claimed_sum = Fraction::new(numerator, denominator); + let [input_numerators_claim, input_denominators_claim] = + input_claims.try_into().unwrap(); + + LookupArtifactInstance::LogUp(LogUpArtifactInstance { + eval_point, + input_n_variables, + input_numerators_claim, + input_denominators_claim, + claimed_sum, + }) + } + Gate::GrandProduct => { + let [claimed_product] = output_claim.try_into().unwrap(); + let [input_claim] = input_claims.try_into().unwrap(); + + LookupArtifactInstance::GrandProduct(GrandProductArtifactInstance { + eval_point, + input_n_variables, + input_claim, + claimed_product, + }) + } + }); + + self.instance += 1; + res + } +} + +// TODO: Consider making the GKR artifact just a Vec. +pub enum LookupArtifactInstance { + GrandProduct(GrandProductArtifactInstance), + LogUp(LogUpArtifactInstance), +} + +pub struct GrandProductArtifactInstance { + /// GKR input layer eval point. + pub eval_point: Vec, + /// Number of variables the MLE in the GKR input layer had. + pub input_n_variables: usize, + /// Claimed input MLE evaluation at `eval_point`. + pub input_claim: SecureField, + /// Output claim from the circuit. + pub claimed_product: SecureField, +} + +pub struct LogUpArtifactInstance { + /// GKR input layer eval point. + pub eval_point: Vec, + /// Number of variables the MLEs in the GKR input layer had. + pub input_n_variables: usize, + /// Claimed input numerators MLE evaluation at `eval_point`. + pub input_numerators_claim: SecureField, + /// Claimed input denominators MLE evaluation at `eval_point`. + pub input_denominators_claim: SecureField, + /// Output claim from the circuit. + pub claimed_sum: Fraction, +} + /// Defines how a circuit operates locally on two input rows to produce a single output row. /// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure. /// @@ -176,7 +279,7 @@ pub struct GkrArtifact { /// circuit) GKR prover implementations. /// /// [Thaler13]: https://eprint.iacr.org/2013/351.pdf -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Gate { LogUp, GrandProduct, @@ -305,11 +408,13 @@ mod tests { let GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance, } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]); + assert_eq!(gate_by_instance, [Gate::GrandProduct, Gate::GrandProduct]); assert_eq!(proof.output_claims_by_instance.len(), 2); assert_eq!(claims_to_verify_by_instance.len(), 2); assert_eq!(proof.output_claims_by_instance[0], &[product0]); @@ -338,11 +443,13 @@ mod tests { let GkrArtifact { ood_point, + gate_by_instance, claims_to_verify_by_instance, n_variables_by_instance, } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]); + assert_eq!(gate_by_instance, [Gate::GrandProduct, Gate::GrandProduct]); assert_eq!(proof.output_claims_by_instance.len(), 2); assert_eq!(claims_to_verify_by_instance.len(), 2); assert_eq!(proof.output_claims_by_instance[0], &[product0]); diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index ca583abe3..a351220d1 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -60,9 +60,9 @@ impl BlakeStatement0 { } pub struct AllElements { - blake_elements: BlakeElements, - round_elements: RoundElements, - xor_elements: BlakeXorElements, + pub blake_elements: BlakeElements, + pub round_elements: RoundElements, + pub xor_elements: BlakeXorElements, } impl AllElements { pub fn draw(channel: &mut impl Channel) -> Self { @@ -222,7 +222,7 @@ where { assert!(log_size >= LOG_N_LANES); assert_eq!( - ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::() as usize, + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), N_ROUNDS ); @@ -239,7 +239,7 @@ where span.exit(); // Prepare inputs. - let blake_inputs = (0..(1 << (log_size - LOG_N_LANES))) + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) .map(|i| { let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16]; let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; @@ -281,18 +281,15 @@ where // Trace commitment. let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - scheduler_trace, - round_traces.into_iter().flatten(), - xor_trace12, - xor_trace9, - xor_trace8, - xor_trace7, - xor_trace4, - ] - .collect_vec(), - ); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); tree_builder.commit(channel); span.exit(); diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index 6fbe6d81b..577d7acf5 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -12,28 +12,28 @@ use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::FieldExpOps; -mod air; -mod round; -mod scheduler; -mod xor_table; +pub mod air; +pub mod round; +pub mod scheduler; +pub mod xor_table; -const STATE_SIZE: usize = 16; -const MESSAGE_SIZE: usize = 16; -const N_FELTS_IN_U32: usize = 2; -const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; +pub const STATE_SIZE: usize = 16; +pub const MESSAGE_SIZE: usize = 16; +pub const N_FELTS_IN_U32: usize = 2; +pub const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; // Parameters for Blake2s. Change these for blake3. -const N_ROUNDS: usize = 10; +pub const N_ROUNDS: usize = 10; /// A splitting N_ROUNDS into several powers of 2. -const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; +pub const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; #[derive(Default)] -struct XorAccums { - xor12: XorAccumulator<12, 4>, - xor9: XorAccumulator<9, 2>, - xor8: XorAccumulator<8, 2>, - xor7: XorAccumulator<7, 2>, - xor4: XorAccumulator<4, 0>, +pub struct XorAccums { + pub xor12: XorAccumulator<12, 4>, + pub xor9: XorAccumulator<9, 2>, + pub xor8: XorAccumulator<8, 2>, + pub xor7: XorAccumulator<7, 2>, + pub xor4: XorAccumulator<4, 0>, } impl XorAccums { fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) { @@ -50,11 +50,11 @@ impl XorAccums { #[derive(Clone)] pub struct BlakeXorElements { - xor12: XorElements, - xor9: XorElements, - xor8: XorElements, - xor7: XorElements, - xor4: XorElements, + pub xor12: XorElements, + pub xor9: XorElements, + pub xor8: XorElements, + pub xor7: XorElements, + pub xor4: XorElements, } impl BlakeXorElements { fn draw(channel: &mut impl Channel) -> Self { @@ -75,7 +75,7 @@ impl BlakeXorElements { xor4: XorElements::dummy(), } } - fn get(&self, w: u32) -> &XorElements { + pub fn get(&self, w: u32) -> &XorElements { match w { 12 => &self.xor12, 9 => &self.xor9, @@ -89,7 +89,7 @@ impl BlakeXorElements { /// Utility for representing a u32 as two field elements, for constraint evaluation. #[derive(Clone, Copy, Debug)] -struct Fu32 +pub struct Fu32 where F: FieldExpOps + Copy @@ -99,8 +99,8 @@ where + Sub + Mul, { - l: F, - h: F, + pub l: F, + pub h: F, } impl Fu32 where @@ -112,7 +112,7 @@ where + Sub + Mul, { - fn to_felts(self) -> [F; 2] { + pub fn to_felts(self) -> [F; 2] { [self.l, self.h] } } diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index 7adbe6fd5..a5f458b03 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -23,19 +23,19 @@ use crate::examples::blake::{to_felts, XorAccums, N_ROUND_INPUT_FELTS, STATE_SIZ pub struct BlakeRoundLookupData { /// A vector of (w, [a_col, b_col, c_col]) for each xor lookup. /// w is the xor width. c_col is the xor col of a_col and b_col. - xor_lookups: Vec<(u32, [BaseColumn; 3])>, + pub xor_lookups: Vec<(u32, [BaseColumn; 3])>, /// A column of round lookup values (v_in, v_out, m). - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], + pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], } pub struct TraceGenerator { - log_size: u32, - trace: Vec, - xor_lookups: Vec<(u32, [BaseColumn; 3])>, - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], + pub log_size: u32, + pub trace: Vec, + pub xor_lookups: Vec<(u32, [BaseColumn; 3])>, + pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], } impl TraceGenerator { - fn new(log_size: u32) -> Self { + pub fn new(log_size: u32) -> Self { assert!(log_size >= LOG_N_LANES); let trace = (0..blake_round_info().mask_offsets[0].len()) .map(|_| unsafe { Col::::uninitialized(1 << log_size) }) @@ -50,7 +50,7 @@ impl TraceGenerator { } } - fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { + pub fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { TraceGeneratorRow { gen: self, col_index: 0, @@ -61,7 +61,7 @@ impl TraceGenerator { } /// Trace generator for the constraints defined at [`super::constraints::BlakeRoundEval`] -struct TraceGeneratorRow<'a> { +pub struct TraceGeneratorRow<'a> { gen: &'a mut TraceGenerator, col_index: usize, vec_row: usize, @@ -79,7 +79,7 @@ impl<'a> TraceGeneratorRow<'a> { self.append_felt(val >> 16); } - fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { + pub fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { let input_v = v; v.iter().for_each(|s| { self.append_u32(*s); @@ -215,7 +215,7 @@ pub fn generate_trace( let _span = span!(Level::INFO, "Round Generation").entered(); let mut generator = TraceGenerator::new(log_size); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let mut row_gen = generator.gen_row(vec_row); let BlakeRoundInput { v, m } = inputs.get(vec_row).copied().unwrap_or_default(); row_gen.generate(v, m); diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index cf8311339..49d798f4b 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -1,7 +1,10 @@ mod constraints; mod gen; -pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; +pub use gen::{ + generate_interaction_trace, generate_trace, BlakeRoundInput, BlakeRoundLookupData, + TraceGenerator, TraceGeneratorRow, +}; use num_traits::Zero; use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index ee9a1c654..6b45f08d4 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -61,7 +61,7 @@ pub fn eval_blake_scheduler_constraints( logup.finalize(eval); } -fn eval_next_u32(eval: &mut E) -> Fu32 { +pub fn eval_next_u32(eval: &mut E) -> Fu32 { let l = eval.next_trace_mask(); let h = eval.next_trace_mask(); Fu32 { l, h } diff --git a/crates/prover/src/examples/blake/scheduler/gen.rs b/crates/prover/src/examples/blake/scheduler/gen.rs index cd6a99b2f..ae3569ed0 100644 --- a/crates/prover/src/examples/blake/scheduler/gen.rs +++ b/crates/prover/src/examples/blake/scheduler/gen.rs @@ -58,7 +58,7 @@ pub fn gen_trace( .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) .collect_vec(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let mut col_index = 0; let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| { @@ -125,11 +125,11 @@ pub fn gen_interaction_trace( let mut logup_gen = LogupTraceGenerator::new(log_size); - for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() { + for [l0, l1] in lookup_data.round_lookups.array_chunks() { let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p0: PackedSecureField = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); let p1: PackedSecureField = @@ -145,7 +145,7 @@ pub fn gen_interaction_trace( // with the entire blake lookup. let mut col_gen = logup_gen.new_col(); #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + for vec_row in 0..1 << (log_size - LOG_N_LANES) { let p_blake: PackedSecureField = blake_lookup_elements.combine( &lookup_data .blake_lookups diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index e8a8c32f3..3035626c8 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -2,7 +2,8 @@ mod constraints; mod gen; use constraints::eval_blake_scheduler_constraints; -pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; +pub use constraints::eval_next_u32; +pub use gen::{gen_interaction_trace, gen_trace, BlakeInput, BlakeSchedulerLookupData}; use num_traits::Zero; use super::round::RoundElements; diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs index 195a6ca46..46309e640 100644 --- a/crates/prover/src/examples/blake/xor_table/gen.rs +++ b/crates/prover/src/examples/blake/xor_table/gen.rs @@ -74,7 +74,7 @@ pub fn generate_interaction_trace( // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { + for vec_row in 0..1 << (column_bits::() - LOG_N_LANES) { // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. // Extract al, blh from vec_row. let al = vec_row >> (limb_bits - LOG_N_LANES); diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 877a65114..b796e702e 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -17,7 +17,9 @@ use std::simd::u32x16; use itertools::Itertools; use num_traits::Zero; -pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; +pub use r#gen::{ + generate_constant_trace, generate_interaction_trace, generate_trace, XorTableLookupData, +}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; @@ -37,7 +39,7 @@ pub fn trace_sizes() -> TreeVec()) } -const fn limb_bits() -> u32 { +pub const fn limb_bits() -> u32 { ELEM_BITS - EXPAND_BITS } pub const fn column_bits() -> u32 { diff --git a/crates/prover/src/examples/blake_gkr/air.rs b/crates/prover/src/examples/blake_gkr/air.rs new file mode 100644 index 000000000..7ac7bfc30 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/air.rs @@ -0,0 +1,526 @@ +use std::array; +use std::simd::u32x16; + +use itertools::{chain, multiunzip, Itertools}; +use tracing::{span, Level}; + +use super::gkr_lookups::MleCoeffColumnOracleAccumulator; +use super::round::{BlakeRoundComponent, BlakeRoundEval}; +use super::scheduler::BlakeSchedulerComponent; +use super::xor_table::{XorLookupArtifacts, XorTableComponent, XorTableEval}; +use crate::constraint_framework::{FrameworkEval, TraceLocationAllocator}; +use crate::core::air::ComponentProver; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::BackendForChannel; +use crate::core::channel::{Channel, MerkleChannel}; +use crate::core::lookups::gkr_prover::prove_batch; +use crate::core::lookups::gkr_verifier::{ + GkrBatchProof, LookupArtifactInstance, LookupArtifactInstanceIter, +}; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; +use crate::core::poly::circle::{CanonicCoset, PolyOps}; +use crate::core::prover::{prove, StarkProof, VerificationError}; +use crate::core::vcs::ops::MerkleHasher; +use crate::examples::blake::air::AllElements; +use crate::examples::blake::scheduler::{self as air_scheduler, BlakeInput}; +use crate::examples::blake::{ + round as air_round, xor_table as air_xor_table, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, +}; +use crate::examples::blake_gkr::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use crate::examples::blake_gkr::gkr_lookups::mle_eval::{self, MleEvalProverComponent}; +use crate::examples::blake_gkr::round::RoundLookupArtifact; +use crate::examples::blake_gkr::scheduler::{BlakeSchedulerEval, SchedulerLookupArtifact}; +use crate::examples::blake_gkr::{round, scheduler, xor_table}; + +pub struct BlakeClaim { + log_size: u32, +} + +impl BlakeClaim { + fn mix_into(&self, channel: &mut impl Channel) { + // TODO(spapini): Do this better. + channel.mix_nonce(self.log_size as u64); + } +} + +pub struct BlakeProof { + pub claim: BlakeClaim, + pub gkr_proof: GkrBatchProof, + pub stark_proof: StarkProof, +} + +pub struct BlakeLookupArtifacts { + scheduler: SchedulerLookupArtifact, + /// `|ROUND_LOG_SPLIT|` many round artifacts. + rounds: Vec, + xor: XorLookupArtifacts, +} + +impl BlakeLookupArtifacts { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + scheduler: SchedulerLookupArtifact::new_from_iter(&mut iter), + rounds: ROUND_LOG_SPLIT + .iter() + .map(|_| RoundLookupArtifact::new_from_iter(&mut iter)) + .collect(), + xor: XorLookupArtifacts::new_from_iter(&mut iter), + } + } + + pub fn verify_succinct_mle_claims( + &self, + lookup_elements: &AllElements, + ) -> Result<(), InvalidClaimError> { + let Self { + scheduler, + rounds, + xor, + } = self; + scheduler.verify_succinct_mle_claims()?; + for round in rounds { + round.verify_succinct_mle_claims()?; + } + xor.verify_succinct_mle_claims(&lookup_elements.xor_elements)?; + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { + scheduler, + rounds, + xor, + } = self; + scheduler.accumulate_mle_eval_iop_claims(acc); + rounds + .iter() + .for_each(|round| round.accumulate_mle_eval_iop_claims(acc)); + xor.accumulate_mle_eval_iop_claims(acc); + } +} + +#[derive(Debug)] +pub struct InvalidClaimError; + +pub struct BlakeComponents { + scheduler_component: BlakeSchedulerComponent, + round_components: Vec, + xor12: XorTableComponent<12, 4>, + xor9: XorTableComponent<9, 2>, + xor8: XorTableComponent<8, 2>, + xor7: XorTableComponent<7, 2>, + xor4: XorTableComponent<4, 0>, +} + +impl BlakeComponents { + pub fn new( + trace_location_allocator: &mut TraceLocationAllocator, + claim: &BlakeClaim, + all_elements: &AllElements, + ) -> Self { + Self { + scheduler_component: BlakeSchedulerComponent::new( + trace_location_allocator, + BlakeSchedulerEval { + log_size: claim.log_size, + blake_lookup_elements: all_elements.blake_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + }, + ), + round_components: ROUND_LOG_SPLIT + .iter() + .map(|l| { + BlakeRoundComponent::new( + trace_location_allocator, + BlakeRoundEval { + log_size: claim.log_size + l, + xor_lookup_elements: all_elements.xor_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + }, + ) + }) + .collect(), + xor12: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor9: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor8: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor7: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + xor4: XorTableComponent::new( + trace_location_allocator, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + }, + ), + } + } + + pub fn accumulate_mle_coeff_col_oracles<'this: 'acc, 'acc>( + &'this self, + acc_by_n_vars: &mut [Option>], + ) { + let Self { + scheduler_component, + round_components, + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + acc_by_n_vars[scheduler_component.log_size as usize] + .as_mut() + .unwrap() + .accumulate(scheduler_component); + for round_component in round_components { + acc_by_n_vars[round_component.log_size as usize] + .as_mut() + .unwrap() + .accumulate(round_component) + } + acc_by_n_vars[xor12.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor12); + acc_by_n_vars[xor9.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor9); + acc_by_n_vars[xor8.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor8); + acc_by_n_vars[xor7.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor7); + acc_by_n_vars[xor4.log_size() as usize] + .as_mut() + .unwrap() + .accumulate(xor4); + } + + fn component_provers(&self) -> Vec<&dyn ComponentProver> { + chain![ + [&self.scheduler_component as &dyn ComponentProver], + self.round_components + .iter() + .map(|c| c as &dyn ComponentProver), + [ + &self.xor12 as &dyn ComponentProver, + &self.xor9 as &dyn ComponentProver, + &self.xor8 as &dyn ComponentProver, + &self.xor7 as &dyn ComponentProver, + &self.xor4 as &dyn ComponentProver, + ] + ] + .collect() + } +} + +pub fn prove_blake(log_size: u32, config: PcsConfig) -> BlakeProof +where + SimdBackend: BackendForChannel, +{ + assert!(log_size >= LOG_N_LANES); + assert_eq!( + ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::(), + N_ROUNDS + ); + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + const XOR_TABLE_MAX_LOG_SIZE: u32 = 16; + let max_log_size = + (log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(max_log_size + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Prepare inputs. + let blake_inputs = (0..1 << (log_size - LOG_N_LANES)) + .map(|i| { + let v = [u32x16::from_array(array::from_fn(|j| (i + 2 * j) as u32)); 16]; + let m = [u32x16::from_array(array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; + BlakeInput { v, m } + }) + .collect_vec(); + + // Setup protocol. + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + + let span = span!(Level::INFO, "Trace").entered(); + + // Scheduler. + let (scheduler_trace, scheduler_lookup_data, round_inputs) = + air_scheduler::gen_trace(log_size, &blake_inputs); + + // Rounds. + let mut xor_accums = XorAccums::default(); + let mut rest = &round_inputs[..]; + // Split round inputs to components, according to [ROUND_LOG_SPLIT]. + let (round_traces, round_lookup_datas): (Vec<_>, Vec<_>) = + multiunzip(ROUND_LOG_SPLIT.map(|l| { + let (cur_inputs, r) = rest.split_at(1 << (log_size - LOG_N_LANES + l)); + rest = r; + air_round::generate_trace(log_size + l, cur_inputs, &mut xor_accums) + })); + + // Xor tables. + let (xor_trace12, xor_lookup_data12) = air_xor_table::generate_trace(xor_accums.xor12); + let (xor_trace9, xor_lookup_data9) = air_xor_table::generate_trace(xor_accums.xor9); + let (xor_trace8, xor_lookup_data8) = air_xor_table::generate_trace(xor_accums.xor8); + let (xor_trace7, xor_lookup_data7) = air_xor_table::generate_trace(xor_accums.xor7); + let (xor_trace4, xor_lookup_data4) = air_xor_table::generate_trace(xor_accums.xor4); + + // Claim. + let claim = BlakeClaim { log_size }; + claim.mix_into(channel); + + // Trace commitment. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ]); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup element. + let all_elements = AllElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let mut lookup_input_layers = Vec::new(); + let mut mle_eval_at_point_collection = MleCollection::default(); + + lookup_input_layers.extend(scheduler::generate_lookup_instances( + log_size, + scheduler_lookup_data, + &all_elements.round_elements, + &all_elements.blake_elements, + &mut mle_eval_at_point_collection, + )); + + ROUND_LOG_SPLIT + .iter() + .zip(round_lookup_datas) + .for_each(|(l, lookup_data)| { + lookup_input_layers.extend(round::generate_lookup_instances( + log_size + l, + lookup_data, + &all_elements.xor_elements, + &all_elements.round_elements, + &mut mle_eval_at_point_collection, + )); + }); + + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data12, + &all_elements.xor_elements.xor12, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data9, + &all_elements.xor_elements.xor9, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data8, + &all_elements.xor_elements.xor8, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data7, + &all_elements.xor_elements.xor7, + &mut mle_eval_at_point_collection, + )); + lookup_input_layers.extend(xor_table::generate_lookup_instances( + xor_lookup_data4, + &all_elements.xor_elements.xor4, + &mut mle_eval_at_point_collection, + )); + + let gkr_span = span!(Level::INFO, "GKR proof").entered(); + let (gkr_proof, gkr_artifact) = prove_batch(channel, lookup_input_layers); + gkr_span.exit(); + let mle_acc_coeff = channel.draw_felt(); + let mles = mle_eval_at_point_collection.random_linear_combine_by_n_variables(mle_acc_coeff); + + // TODO(andrew): Consider unifying new_from_iter, verify_succinct_mle_claims, + // accumulate_mle_eval_iop_claims. + let mut lookup_instances_iter = LookupArtifactInstanceIter::new(&gkr_proof, &gkr_artifact); + let blake_lookup_artifacts = BlakeLookupArtifacts::new_from_iter(&mut lookup_instances_iter); + assert!(lookup_instances_iter.next().is_none()); + blake_lookup_artifacts + .verify_succinct_mle_claims(&all_elements) + .unwrap(); + let mut mle_eval_iop_acc = MleClaimAccumulator::new(mle_acc_coeff); + blake_lookup_artifacts.accumulate_mle_eval_iop_claims(&mut mle_eval_iop_acc); + let mut mle_claim_by_n_variables = mle_eval_iop_acc.finalize(); + + let max_mle_n_variables = mles.iter().map(|mle| mle.n_variables()).max().unwrap(); + let mut mle_coeff_col_acc_by_n_variables = vec![None; max_mle_n_variables + 1]; + + for mle in &mles { + let n_variables = mle.n_variables(); + mle_coeff_col_acc_by_n_variables[n_variables] = + Some(MleCoeffColumnOracleAccumulator::new(mle_acc_coeff)); + } + + let trace_location_allocator = &mut TraceLocationAllocator::default(); + let blake_components = BlakeComponents::new(trace_location_allocator, &claim, &all_elements); + blake_components.accumulate_mle_coeff_col_oracles(&mut mle_coeff_col_acc_by_n_variables); + + let mut tree_builder = commitment_scheme.tree_builder(); + let mle_eval_prover_components = mles + .into_iter() + .map(|mle| { + let n_vars = mle.n_variables(); + let coeff_column_oracle = mle_coeff_col_acc_by_n_variables[n_vars].as_ref().unwrap(); + let claim = mle_claim_by_n_variables[n_vars].take().unwrap(); + let eval_point = gkr_artifact.ood_point(n_vars); + + tree_builder.extend_evals(mle_eval::build_trace(&mle, eval_point, claim)); + + // Sanity check the claims. + #[cfg(test)] + debug_assert_eq!(claim, mle.eval_at_point(eval_point)); + + MleEvalProverComponent::generate( + trace_location_allocator, + coeff_column_oracle, + eval_point, + mle, + claim, + &twiddles, + 1, + ) + }) + .collect_vec(); + tree_builder.commit(channel); + span.exit(); + + let components = chain![ + blake_components.component_provers(), + mle_eval_prover_components + .iter() + .map(|c| c as &dyn ComponentProver) + ] + .collect_vec(); + + let stark_proof = prove(&components, channel, commitment_scheme).unwrap(); + + BlakeProof { + claim, + gkr_proof, + stark_proof, + } +} + +#[allow(unused)] +pub fn verify_blake( + BlakeProof { + claim, + gkr_proof, + stark_proof, + }: BlakeProof, + config: PcsConfig, +) -> Result<(), VerificationError> { + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + // let log_sizes = stmt0.log_sizes(); + + // // Trace. + // stmt0.mix_into(channel); + // commitment_scheme.commit(stark_proof.commitments[0], &log_sizes[0], channel); + + // // Draw interaction elements. + // let all_elements = AllElements::draw(channel); + + // // Interaction trace. + // stmt1.mix_into(channel); + // commitment_scheme.commit(stark_proof.commitments[1], &log_sizes[1], channel); + + // // Constant trace. + // commitment_scheme.commit(stark_proof.commitments[2], &log_sizes[2], channel); + + // let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + + // // Check that all sums are correct. + // let total_sum = stmt1.scheduler_claimed_sum + // + stmt1.round_claimed_sums.iter().sum::() + // + stmt1.xor12_claimed_sum + // + stmt1.xor9_claimed_sum + // + stmt1.xor8_claimed_sum + // + stmt1.xor7_claimed_sum + // + stmt1.xor4_claimed_sum; + + // // TODO(spapini): Add inputs to sum, and constraint them. + // assert_eq!(total_sum, SecureField::zero()); + + // verify( + // &components.components(), + // channel, + // commitment_scheme, + // stark_proof, + // ) + + todo!() +} + +#[cfg(test)] +mod tests { + use std::env; + + use crate::core::pcs::PcsConfig; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::blake_gkr::air::prove_blake; + + // Note: this test is slow. Only run in release. + #[cfg_attr(not(feature = "slow-tests"), ignore)] + #[test_log::test] + fn test_simd_blake_gkr_prove() { + // Get from environment variable: + let log_n_instances = env::var("LOG_N_INSTANCES") + .unwrap_or_else(|_| "6".to_string()) + .parse::() + .unwrap(); + let config = PcsConfig::default(); + + // Prove. + let _proof = prove_blake::(log_n_instances, config); + + // Verify. + // verify_blake::(proof, config).unwrap(); + } +} diff --git a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs similarity index 84% rename from crates/prover/src/examples/xor/gkr_lookups/accumulation.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs index 986289572..d3c49aed9 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/accumulation.rs @@ -2,7 +2,7 @@ use std::iter::zip; use std::ops::{AddAssign, Mul}; use educe::Educe; -use num_traits::One; +use num_traits::{One, Zero}; use crate::core::backend::simd::SimdBackend; use crate::core::backend::Backend; @@ -16,7 +16,7 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; /// Max number of variables for multilinear polynomials that get compiled into a univariate /// IOP for multilinear eval at point. -pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; +pub const MAX_MLE_N_VARIABLES: usize = (M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR) as usize; /// Collection of [`Mle`]s grouped by their number of variables. pub struct MleCollection { @@ -90,7 +90,7 @@ pub fn combine + Copy, F: Copy>( impl Default for MleCollection { fn default() -> Self { Self { - mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES as usize + 1], + mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES + 1], } } } @@ -133,6 +133,32 @@ impl DynMle { } } +/// Accumulates claims of multilinear polynomials, grouped by their number of variables. +// TODO(andrew): Consider group by eval point to make sure everything done correctly. +pub struct MleClaimAccumulator { + acc_coeff: SecureField, + acc_by_n_variables: Vec>, +} + +impl MleClaimAccumulator { + pub fn new(acc_coeff: SecureField) -> Self { + Self { + acc_coeff, + acc_by_n_variables: vec![None; MAX_MLE_N_VARIABLES + 1], + } + } + + pub fn accumulate(&mut self, n_variables: usize, evaluation: SecureField) { + let acc = self.acc_by_n_variables[n_variables].get_or_insert_with(SecureField::zero); + *acc = *acc * self.acc_coeff + evaluation; + } + + /// Returns a mapping of number of variables to claim accumulation. + pub fn finalize(self) -> Vec> { + self.acc_by_n_variables + } +} + #[cfg(test)] mod tests { use std::iter::repeat; @@ -144,7 +170,7 @@ mod tests { use crate::core::fields::qm31::SecureField; use crate::core::fields::Field; use crate::core::lookups::mle::{Mle, MleOps}; - use crate::examples::xor::gkr_lookups::accumulation::MleCollection; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MleCollection; #[test] fn random_linear_combine_by_n_variables() { diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs similarity index 96% rename from crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs rename to crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs index 9d05ce7aa..01ae85ce6 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/mle_eval.rs @@ -1,7 +1,4 @@ //! Multilinear extension (MLE) eval at point constraints. -// TODO(andrew): Remove in downstream PR. -#![allow(dead_code)] - use std::iter::zip; use itertools::{chain, zip_eq, Itertools}; @@ -650,6 +647,36 @@ fn eval_step_selector_with_offset( eval_step_selector(coset, log_step, p - offset_step.into_ef()) } +// /// Returns `log(|coset|)` evaluations where the `i`th evaluation is of a polynomial that's `1` +// /// every `2^i` coset points and `0` elsewhere on coset. +// fn eval_step_selectors_by_log_step(coset: Coset, p: CirclePoint) -> Vec +// { let res = vec![SecureField::one()]; + +// if log_step == 0 { +// return SecureField::one(); +// } + +// // Rotate the coset to have points on the `x` axis. +// let p = p - coset.initial.into_ef(); +// let mut vanish_at_log_step = (0..coset.log_size) +// .scan(p, |p, _| { +// let res = *p; +// *p = p.double(); +// Some(res.y) +// }) +// .collect_vec(); +// vanish_at_log_step.reverse(); +// let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()]; +// SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv); + +// let norm = BaseField::from(2).inverse(); + +// let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square(); +// let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::(); +// (half_coset_selector_dbl + vanish_at_log_step[0] * vanish_substep_inv_sum.double()) +// / BaseField::from(1 << (log_step + 1)) +// } + /// Evaluates a polynomial that's `1` every `2^log_step` coset points and `0` elsewhere on coset. fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) -> SecureField { if log_step == 0 { @@ -743,8 +770,8 @@ mod tests { use crate::core::prover::{prove, verify, VerificationError}; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; - use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; + use crate::examples::blake_gkr::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::eval_step_selector_with_offset; #[test] fn mle_eval_prover_component() -> Result<(), VerificationError> { @@ -814,7 +841,6 @@ mod tests { const N_VARIABLES: usize = 8; const COEFFS_COL_TRACE: usize = 0; const MLE_EVAL_TRACE: usize = 1; - const CONST_TRACE: usize = 2; const LOG_EXPAND: u32 = 1; // Create the test MLE. let mut rng = SmallRng::seed_from_u64(0); @@ -1124,7 +1150,7 @@ mod tests { use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::ColumnVec; - use crate::examples::xor::gkr_lookups::mle_eval::MleCoeffColumnOracle; + use crate::examples::blake_gkr::gkr_lookups::mle_eval::MleCoeffColumnOracle; pub type MleCoeffColumnComponent = FrameworkComponent; diff --git a/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs b/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs new file mode 100644 index 000000000..ab5319db5 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/gkr_lookups/mod.rs @@ -0,0 +1,54 @@ +use mle_eval::MleCoeffColumnOracle; + +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::circle::CirclePoint; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; + +pub mod accumulation; +pub mod mle_eval; + +// TODO(andrew): Try come up with less verbose name. +pub trait AccumulatedMleCoeffColumnOracle { + fn accumulate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ); +} + +// TODO(andrew): Try come up with less verbose name. +#[derive(Clone)] +pub struct MleCoeffColumnOracleAccumulator<'a> { + acc_coeff: SecureField, + oracles: Vec<&'a dyn AccumulatedMleCoeffColumnOracle>, +} + +impl<'a> MleCoeffColumnOracleAccumulator<'a> { + pub fn new(acc_coeff: SecureField) -> Self { + Self { + acc_coeff, + oracles: Vec::new(), + } + } + + pub fn accumulate<'b: 'a>(&mut self, oracle: &'b dyn AccumulatedMleCoeffColumnOracle) { + self.oracles.push(oracle) + } +} + +impl<'a> MleCoeffColumnOracle for MleCoeffColumnOracleAccumulator<'a> { + fn evaluate_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + ) -> SecureField { + let mut acc = PointEvaluationAccumulator::new(self.acc_coeff); + for oracle in &self.oracles { + oracle.accumulate_at_point(point, mask, &mut acc); + } + acc.finalize() + } +} diff --git a/crates/prover/src/examples/blake_gkr/mod.rs b/crates/prover/src/examples/blake_gkr/mod.rs new file mode 100644 index 000000000..a1f767c6b --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/mod.rs @@ -0,0 +1,5 @@ +pub mod air; +pub mod gkr_lookups; +pub mod round; +pub mod scheduler; +pub mod xor_table; diff --git a/crates/prover/src/examples/blake_gkr/round.rs b/crates/prover/src/examples/blake_gkr/round.rs new file mode 100644 index 000000000..8e868b93f --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/round.rs @@ -0,0 +1,339 @@ +use std::array; + +use itertools::{chain, Itertools}; +use num_traits::One; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use super::gkr_lookups::AccumulatedMleCoeffColumnOracle; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator}; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; +use crate::examples::blake::round::{BlakeRoundLookupData, RoundElements, TraceGenerator}; +use crate::examples::blake::{BlakeXorElements, Fu32, STATE_SIZE}; + +pub type BlakeRoundComponent = FrameworkComponent; + +pub struct BlakeRoundEval { + pub log_size: u32, + pub xor_lookup_elements: BlakeXorElements, + pub round_lookup_elements: RoundElements, +} + +impl FrameworkEval for BlakeRoundEval { + fn log_size(&self) -> u32 { + self.log_size + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + 1 + } + fn evaluate(&self, eval: E) -> E { + const MLE_COEFF_COL_EVAL: bool = false; + let blake_eval = BlakeRoundConstraintEval:: { + eval, + xor_lookup_elements: &self.xor_lookup_elements, + round_lookup_elements: &self.round_lookup_elements, + mle_coeff_col_evals: None, + }; + blake_eval.eval() + } +} + +impl AccumulatedMleCoeffColumnOracle for BlakeRoundComponent { + fn accumulate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ) { + // Create dummy point evaluator just to extract the value we need from the mask + let mut _accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut _accumulator, + SecureField::one(), + ); + + let mut mle_coef_col_evals = Vec::new(); + + const MLE_COEFF_COL_EVAL: bool = true; + BlakeRoundConstraintEval::<_, MLE_COEFF_COL_EVAL> { + eval, + xor_lookup_elements: &self.xor_lookup_elements, + round_lookup_elements: &self.round_lookup_elements, + mle_coeff_col_evals: Some(&mut mle_coef_col_evals), + } + .eval(); + + for eval in mle_coef_col_evals { + acc.accumulate(eval) + } + } +} + +const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15); +const TWO: BaseField = BaseField::from_u32_unchecked(2); + +pub struct BlakeRoundConstraintEval<'a, E: EvalAtRow, const MLE_COEFF_COL_EVAL: bool> { + pub eval: E, + pub xor_lookup_elements: &'a BlakeXorElements, + pub round_lookup_elements: &'a RoundElements, + pub mle_coeff_col_evals: Option<&'a mut Vec>, +} +impl<'a, E: EvalAtRow, const MLE_COEFF_COL_EVAL: bool> + BlakeRoundConstraintEval<'a, E, MLE_COEFF_COL_EVAL> +{ + pub fn eval(mut self) -> E { + let mut v: [Fu32; STATE_SIZE] = array::from_fn(|_| self.next_u32()); + let input_v = v; + let m: [Fu32; STATE_SIZE] = array::from_fn(|_| self.next_u32()); + + self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); + self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); + self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); + self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); + self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); + self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); + self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); + self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); + + if MLE_COEFF_COL_EVAL { + self.mle_coeff_col_evals.as_mut().unwrap().push( + self.round_lookup_elements.combine( + &chain![ + input_v.iter().copied().flat_map(Fu32::to_felts), + v.iter().copied().flat_map(Fu32::to_felts), + m.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + ), + ); + } + + self.eval + } + fn next_u32(&mut self) -> Fu32 { + let l = self.eval.next_trace_mask(); + let h = self.eval.next_trace_mask(); + Fu32 { l, h } + } + fn g(&mut self, v: [&mut Fu32; 4], m0: Fu32, m1: Fu32) { + let [a, b, c, d] = v; + + *a = self.add3_u32_unchecked(*a, *b, m0); + *d = self.xor_rotr16_u32(*a, *d); + *c = self.add2_u32_unchecked(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 12); + *a = self.add3_u32_unchecked(*a, *b, m1); + *d = self.xor_rotr_u32(*a, *d, 8); + *c = self.add2_u32_unchecked(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 7); + } + + /// Adds two u32s, returning the sum. + /// Assumes a, b are properly range checked. + /// The caller is responsible for checking: + /// res.{l,h} not in [2^16, 2^17) or in [-2^16,0) + fn add2_u32_unchecked(&mut self, a: Fu32, b: Fu32) -> Fu32 { + let sl = self.eval.next_trace_mask(); + let sh = self.eval.next_trace_mask(); + + let carry_l = (a.l + b.l - sl) * E::F::from(INV16); + self.eval.add_constraint(carry_l * carry_l - carry_l); + + let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(INV16); + self.eval.add_constraint(carry_h * carry_h - carry_h); + + Fu32 { l: sl, h: sh } + } + + /// Adds three u32s, returning the sum. + /// Assumes a, b, c are properly range checked. + /// Caller is responsible for checking: + /// res.{l,h} not in [2^16, 3*2^16) or in [-2^17,0) + fn add3_u32_unchecked(&mut self, a: Fu32, b: Fu32, c: Fu32) -> Fu32 { + let sl = self.eval.next_trace_mask(); + let sh = self.eval.next_trace_mask(); + + let carry_l = (a.l + b.l + c.l - sl) * E::F::from(INV16); + self.eval + .add_constraint(carry_l * (carry_l - E::F::one()) * (carry_l - E::F::from(TWO))); + + let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(INV16); + self.eval + .add_constraint(carry_h * (carry_h - E::F::one()) * (carry_h - E::F::from(TWO))); + + Fu32 { l: sl, h: sh } + } + + /// Splits a felt at r. + /// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap. + fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) { + let h = self.eval.next_trace_mask(); + let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r)); + (l, h) + } + + /// Checks that a, b are in range, and computes their xor rotated right by `r` bits. + /// Guarantees that all elements are in range. + fn xor_rotr_u32(&mut self, a: Fu32, b: Fu32, r: u32) -> Fu32 { + let (all, alh) = self.split_unchecked(a.l, r); + let (ahl, ahh) = self.split_unchecked(a.h, r); + let (bll, blh) = self.split_unchecked(b.l, r); + let (bhl, bhh) = self.split_unchecked(b.h, r); + + // These also guarantee that all elements are in range. + let xorll = self.xor(r, all, bll); + let xorlh = self.xor(16 - r, alh, blh); + let xorhl = self.xor(r, ahl, bhl); + let xorhh = self.xor(16 - r, ahh, bhh); + + Fu32 { + l: xorhl * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorlh, + h: xorll * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorhh, + } + } + + /// Checks that a, b are in range, and computes their xor rotated right by 16 bits. + /// Guarantees that all elements are in range. + fn xor_rotr16_u32(&mut self, a: Fu32, b: Fu32) -> Fu32 { + let (all, alh) = self.split_unchecked(a.l, 8); + let (ahl, ahh) = self.split_unchecked(a.h, 8); + let (bll, blh) = self.split_unchecked(b.l, 8); + let (bhl, bhh) = self.split_unchecked(b.h, 8); + + // These also guarantee that all elements are in range. + let xorll = self.xor(8, all, bll); + let xorlh = self.xor(8, alh, blh); + let xorhl = self.xor(8, ahl, bhl); + let xorhh = self.xor(8, ahh, bhh); + + Fu32 { + l: xorhh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorhl, + h: xorlh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorll, + } + } + + /// Checks that a, b are in [0, 2^w) and computes their xor. + fn xor(&mut self, w: u32, a: E::F, b: E::F) -> E::F { + // TODO: Separate lookups by w. + let c = self.eval.next_trace_mask(); + + if MLE_COEFF_COL_EVAL { + let lookup_elements = self.xor_lookup_elements.get(w); + self.mle_coeff_col_evals + .as_mut() + .unwrap() + .push(lookup_elements.combine(&[a, b, c])); + } + + c + } +} + +pub struct RoundLookupArtifact { + pub round: LogUpArtifactInstance, + pub xors: Vec, +} + +impl RoundLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + let xors = (0..n_xor_lookups()) + .map(|_| match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }) + .collect(); + + let round = match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }; + + Self { round, xors } + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { round, xors } = self; + + for xor in xors { + acc.accumulate(xor.input_n_variables, xor.input_denominators_claim); + } + + acc.accumulate(round.input_n_variables, round.input_denominators_claim); + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { round, xors } = self; + + for xor_artifact in xors { + if !xor_artifact.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + } + + if !round.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + + Ok(()) + } +} + +/// Returns an ordered list of all XOR lookup types the round component uses. +fn n_xor_lookups() -> usize { + // Create a dummy trace to extract the structural xor lookup information. + let mut trace_generator = TraceGenerator::new(LOG_N_LANES); + let mut row = trace_generator.gen_row(0); + row.generate(Default::default(), Default::default()); + trace_generator.xor_lookups.len() +} + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeRoundLookupData, + xor_lookup_elements: &BlakeXorElements, + round_lookup_elements: &RoundElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate round interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for (w, l) in &lookup_data.xor_lookups { + let lookup_elements = xor_lookup_elements.get(*w); + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = lookup_elements.combine(&l.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }); + } + + // Blake round lookup. + let mut round_denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements + .combine(&lookup_data.round_lookup.each_ref().map(|l| l.data[vec_row])); + round_denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(round_denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { + denominators: round_denominators, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/scheduler.rs b/crates/prover/src/examples/blake_gkr/scheduler.rs new file mode 100644 index 000000000..b1dd50dd3 --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/scheduler.rs @@ -0,0 +1,209 @@ +use std::array; + +use itertools::{chain, Itertools}; +use num_traits::{One, Zero}; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use super::gkr_lookups::AccumulatedMleCoeffColumnOracle; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator}; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::blake2s::SIGMA; +use crate::core::backend::simd::column::{BaseColumn, SecureColumn}; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; +use crate::examples::blake::round::RoundElements; +use crate::examples::blake::scheduler::{eval_next_u32, BlakeElements, BlakeSchedulerLookupData}; +use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; + +pub type BlakeSchedulerComponent = FrameworkComponent; + +pub struct BlakeSchedulerEval { + pub log_size: u32, + pub blake_lookup_elements: BlakeElements, + pub round_lookup_elements: RoundElements, +} + +impl FrameworkEval for BlakeSchedulerEval { + fn log_size(&self) -> u32 { + self.log_size + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + } + + fn evaluate(&self, mut eval: E) -> E { + let _ = SchedulerEvals::new(&mut eval); + eval + } +} + +impl AccumulatedMleCoeffColumnOracle for BlakeSchedulerComponent { + fn accumulate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ) { + // Create dummy point evaluator just to extract the value we need from the mask + let mut _accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut _accumulator, + SecureField::one(), + ); + + let SchedulerEvals { messages, states } = SchedulerEvals::new(&mut eval); + + // Schedule. + for i in 0..N_ROUNDS { + let input_state = &states[i]; + let output_state = &states[i + 1]; + let round_messages = SIGMA[i].map(|j| messages[j as usize]); + // Use triplet in round lookup. + let lookup_values = &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + round_messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(); + let denoms_mle_coeff_col_eval = self.round_lookup_elements.combine(lookup_values); + acc.accumulate(denoms_mle_coeff_col_eval); + } + + let input_state = &states[0]; + let output_state = &states[N_ROUNDS]; + let lookup_values = &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(); + let denoms_mle_coeff_col_eval = self.blake_lookup_elements.combine(lookup_values); + acc.accumulate(denoms_mle_coeff_col_eval); + } +} + +struct SchedulerEvals { + messages: [Fu32; STATE_SIZE], + states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1], +} + +impl SchedulerEvals { + fn new(eval: &mut E) -> Self { + Self { + messages: array::from_fn(|_| eval_next_u32(eval)), + states: array::from_fn(|_| array::from_fn(|_| eval_next_u32(eval))), + } + } +} + +pub struct SchedulerLookupArtifact { + scheduler: LogUpArtifactInstance, + rounds: [LogUpArtifactInstance; N_ROUNDS], +} + +impl SchedulerLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + let rounds = array::from_fn(|_| match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }); + + let scheduler = match iter.next() { + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }; + + Self { scheduler, rounds } + } + + pub fn verify_succinct_mle_claims(&self) -> Result<(), InvalidClaimError> { + let Self { scheduler, rounds } = self; + + // TODO(andrew): Consider checking the n_variables is correct. + // if !self.scheduler.input_numerators_claim.is_one() { + if !scheduler.input_numerators_claim.is_zero() { + return Err(InvalidClaimError); + } + + for round in rounds { + if !round.input_numerators_claim.is_one() { + return Err(InvalidClaimError); + } + } + + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { scheduler, rounds } = self; + + for round in rounds { + acc.accumulate(round.input_n_variables, round.input_denominators_claim); + } + + // TODO: Note `n_variables` is not verified. Probably fine since if the prover gives wrong + // info they'll be caught. Can panic though if the n_variables is too high. Consider + // checking the number of GKR layers in the verifier is less than + // LOG_CIRCLE_ORDER-LOG_BLOWUP-LOG_EXPAND. + acc.accumulate( + scheduler.input_n_variables, + scheduler.input_denominators_claim, + ); + } +} + +pub fn generate_lookup_instances( + log_size: u32, + lookup_data: BlakeSchedulerLookupData, + round_lookup_elements: &RoundElements, + blake_lookup_elements: &BlakeElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Generate scheduler interaction trace").entered(); + let size = 1 << log_size; + let mut round_lookup_layers = Vec::new(); + + for l0 in &lookup_data.round_lookups { + let mut denominators = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let denom = round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); + denominators.data[vec_row] = denom; + } + collection_for_univariate_iop.push(denominators.clone()); + round_lookup_layers.push(Layer::LogUpSingles { denominators }) + } + + // Blake hash lookup. + let blake_numers = Mle::::new(BaseColumn::zeros(size)); + let mut blake_denoms = Mle::::new(SecureColumn::zeros(size)); + for vec_row in 0..1 << (log_size - LOG_N_LANES) { + let blake_denom: PackedSecureField = blake_lookup_elements.combine( + &lookup_data + .blake_lookups + .each_ref() + .map(|l| l.data[vec_row]), + ); + blake_denoms.data[vec_row] = blake_denom; + } + collection_for_univariate_iop.push(blake_denoms.clone()); + round_lookup_layers.push(Layer::LogUpMultiplicities { + numerators: blake_numers, + denominators: blake_denoms, + }); + + round_lookup_layers +} diff --git a/crates/prover/src/examples/blake_gkr/xor_table.rs b/crates/prover/src/examples/blake_gkr/xor_table.rs new file mode 100644 index 000000000..18b5fe31e --- /dev/null +++ b/crates/prover/src/examples/blake_gkr/xor_table.rs @@ -0,0 +1,328 @@ +use std::array; +use std::iter::zip; +use std::simd::u32x16; + +use itertools::Itertools; +use num_traits::{One, Zero}; +use tracing::{span, Level}; + +use super::air::InvalidClaimError; +use super::gkr_lookups::accumulation::{MleClaimAccumulator, MleCollection}; +use super::gkr_lookups::AccumulatedMleCoeffColumnOracle; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator}; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::Field; +use crate::core::lookups::gkr_prover::Layer; +use crate::core::lookups::gkr_verifier::{LogUpArtifactInstance, LookupArtifactInstance}; +use crate::core::lookups::mle::Mle; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; +use crate::examples::blake::xor_table::{column_bits, limb_bits, XorElements, XorTableLookupData}; +use crate::examples::blake::BlakeXorElements; + +/// Component that evaluates the xor table. +pub type XorTableComponent = + FrameworkComponent>; + +/// Evaluates the xor table. +pub struct XorTableEval { + pub lookup_elements: XorElements, +} + +impl FrameworkEval + for XorTableEval +{ + fn log_size(&self) -> u32 { + column_bits::() + } + fn max_constraint_log_degree_bound(&self) -> u32 { + column_bits::() + 1 + } + fn evaluate(&self, mut eval: E) -> E { + let _ = eval_xor_table_multiplicity_cols::(&mut eval); + eval + } +} + +impl AccumulatedMleCoeffColumnOracle + for XorTableComponent +{ + fn accumulate_at_point( + &self, + _point: CirclePoint, + mask: &TreeVec>>, + acc: &mut PointEvaluationAccumulator, + ) { + // Create dummy point evaluator just to extract the value we need from the mask + let mut _accumulator = PointEvaluationAccumulator::new(SecureField::one()); + let mut eval = PointEvaluator::new( + mask.sub_tree(self.trace_locations()), + &mut _accumulator, + SecureField::one(), + ); + + for eval in eval_xor_table_multiplicity_cols::<_, ELEM_BITS, EXPAND_BITS>(&mut eval) { + acc.accumulate(eval) + } + } +} + +fn eval_xor_table_multiplicity_cols( + eval: &mut E, +) -> Vec { + (0..1 << (2 * EXPAND_BITS)) + .map(|_| eval.next_trace_mask()) + .collect() +} + +pub struct XorLookupArtifacts { + xor12: XorLookupArtifact<12, 4>, + xor9: XorLookupArtifact<9, 2>, + xor8: XorLookupArtifact<8, 2>, + xor7: XorLookupArtifact<7, 2>, + xor4: XorLookupArtifact<4, 0>, +} + +impl XorLookupArtifacts { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + xor12: XorLookupArtifact::new_from_iter(&mut iter), + xor9: XorLookupArtifact::new_from_iter(&mut iter), + xor8: XorLookupArtifact::new_from_iter(&mut iter), + xor7: XorLookupArtifact::new_from_iter(&mut iter), + xor4: XorLookupArtifact::new_from_iter(&mut iter), + } + } + + pub fn verify_succinct_mle_claims( + &self, + lookup_elements: &BlakeXorElements, + ) -> Result<(), InvalidClaimError> { + let Self { + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + + xor12.verify_succinct_mle_claims(lookup_elements.get(12))?; + xor9.verify_succinct_mle_claims(lookup_elements.get(9))?; + xor8.verify_succinct_mle_claims(lookup_elements.get(8))?; + xor7.verify_succinct_mle_claims(lookup_elements.get(7))?; + xor4.verify_succinct_mle_claims(lookup_elements.get(4))?; + + Ok(()) + } + + pub fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { + xor12, + xor9, + xor8, + xor7, + xor4, + } = self; + xor12.accumulate_mle_eval_iop_claims(acc); + xor9.accumulate_mle_eval_iop_claims(acc); + xor8.accumulate_mle_eval_iop_claims(acc); + xor7.accumulate_mle_eval_iop_claims(acc); + xor4.accumulate_mle_eval_iop_claims(acc); + } +} + +pub struct XorLookupArtifact { + /// `2^(2*EXPAND_BITS)` many LogUp instances. + artifacts: Vec, +} + +impl XorLookupArtifact { + pub fn new_from_iter(mut iter: impl Iterator) -> Self { + Self { + artifacts: (0..1 << (2 * EXPAND_BITS)) + .map(|_| match iter.next() { + // TODO: check input MLEs have expected number of variables. + Some(LookupArtifactInstance::LogUp(artifact)) => artifact, + _ => panic!(), + }) + .collect(), + } + } + + fn verify_succinct_mle_claims( + &self, + lookup_elements: &XorElements, + ) -> Result<(), InvalidClaimError> { + for (i, artifact) in self.artifacts.iter().enumerate() { + let eval_point = &artifact.eval_point; + let denoms_claim = artifact.input_denominators_claim; + let denoms_eval = eval_logup_denominators_mle::( + i, + lookup_elements, + eval_point, + ) + .unwrap(); + + if denoms_claim != denoms_eval { + return Err(InvalidClaimError); + } + } + + Ok(()) + } + + fn accumulate_mle_eval_iop_claims(&self, acc: &mut MleClaimAccumulator) { + let Self { artifacts } = self; + for artifact in artifacts { + acc.accumulate(artifact.input_n_variables, artifact.input_numerators_claim); + } + } +} + +pub fn generate_lookup_instances( + lookup_data: XorTableLookupData, + lookup_elements: &XorElements, + collection_for_univariate_iop: &mut MleCollection, +) -> Vec> { + let _span = span!(Level::INFO, "Xor interaction trace").entered(); + let mut xor_lookup_layers = Vec::new(); + + // There are 2^(2*EXPAND_BITS) columns, for each combination of ah, bh. + for (column_index, mults) in lookup_data.xor_accum.mults.iter().enumerate() { + let numerators = Mle::::new(mults.clone()); + collection_for_univariate_iop.push(numerators.clone()); + let denominators = + gen_logup_denominators_mle::(column_index, lookup_elements); + xor_lookup_layers.push(Layer::LogUpMultiplicities { + numerators, + denominators, + }); + } + + xor_lookup_layers +} + +/// Returns an MLE representing the LogUp denominator terms for the xor table. +fn gen_logup_denominators_mle( + column_index: usize, + lookup_elements: &XorElements, +) -> Mle { + let offsets_vec = u32x16::from_array(array::from_fn(|i| i as u32)); + let column_bits = column_bits::(); + let column_size = 1 << column_bits; + let mut denominators = Mle::::new(SecureColumn::zeros(column_size)); + + // Extract ah, bh from column index. + let ah = column_index as u32 >> EXPAND_BITS; + let bh = column_index as u32 & ((1 << EXPAND_BITS) - 1); + + // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. + for vec_row in 0..1 << (column_bits - LOG_N_LANES) { + let limb_bits = limb_bits::(); + + // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. + // Extract al, blh from vec_row. + let al = vec_row >> (limb_bits - LOG_N_LANES); + let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); + + // Construct the 3 vectors a, b, c. + let a = u32x16::splat((ah << limb_bits) | al); + // bll is just the consecutive numbers 0..N_LANES-1. + let b = u32x16::splat((bh << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; + let c = a ^ b; + + let denominator = lookup_elements + .combine(&[a, b, c].map(|v| unsafe { PackedBaseField::from_simd_unchecked(v) })); + denominators.data[vec_row as usize] = denominator; + } + + denominators +} + +/// Evaluates the succinct MLE representing the LogUp denominator terms for the xor table. +/// +/// Evaluates the MLE returned by [`gen_logup_denominators_mle`]. +fn eval_logup_denominators_mle( + column_index: usize, + lookup_elements: &XorElements, + eval_point: &[SecureField], +) -> Result { + assert!(column_index < 1 << (2 * EXPAND_BITS)); + let limb_bits = limb_bits::() as usize; + if eval_point.len() != limb_bits * 2 { + return Err(InvalidEvalPoint); + } + + let (al_assignment, bl_assignment) = eval_point.split_at(limb_bits); + let cl_assignment = &zip(al_assignment, bl_assignment) + // Note `a ^ b = a + b - 2 * a * b` for all `a, b` in `{0, 1}`. + .map(|(&li, &ri)| li + ri - (li * ri).double()) + .collect_vec(); + + let al = pack_little_endian_bits(al_assignment); + let bl = pack_little_endian_bits(bl_assignment); + let cl = pack_little_endian_bits(cl_assignment); + + // Extract ah, bh from column index. + let ah = column_index >> EXPAND_BITS; + let bh = column_index & ((1 << EXPAND_BITS) - 1); + let ch = ah ^ bh; + + let a = al + BaseField::from(ah << limb_bits); + let b = bl + BaseField::from(bh << limb_bits); + let c = cl + BaseField::from(ch << limb_bits); + + Ok(lookup_elements.combine(&[a, b, c])) +} + +fn pack_little_endian_bits(bits: &[SecureField]) -> SecureField { + bits.iter() + .fold(SecureField::zero(), |acc, &bit| acc.double() + bit) +} + +/// Eval point is invalid. +#[derive(Debug)] +struct InvalidEvalPoint; + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use crate::core::channel::Channel; + use crate::core::test_utils::test_channel; + use crate::examples::blake::xor_table::XorElements; + use crate::examples::blake_gkr::xor_table::{ + eval_logup_denominators_mle, gen_logup_denominators_mle, + }; + + #[test] + fn eval_logup_denominators_mle_works() { + const ELEM_BITS: u32 = 8; + const EXPAND_BITS: u32 = 2; + let column_index = 0b1011; + assert!((0..1 << (2 * EXPAND_BITS)).contains(&column_index)); + let channel = &mut test_channel(); + let lookup_elements = XorElements::draw(channel); + let denominators_mle = + gen_logup_denominators_mle::(column_index, &lookup_elements); + let eval_point = (0..denominators_mle.n_variables()) + .map(|_| channel.draw_felt()) + .collect_vec(); + + let eval = eval_logup_denominators_mle::( + column_index, + &lookup_elements, + &eval_point, + ) + .unwrap(); + + assert_eq!(eval, denominators_mle.eval_at_point(&eval_point)); + } +} diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index 330662de9..0ad6f301a 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,5 +1,5 @@ pub mod blake; +pub mod blake_gkr; pub mod plonk; pub mod poseidon; pub mod wide_fibonacci; -pub mod xor; diff --git a/crates/prover/src/examples/xor/gkr_lookups/mod.rs b/crates/prover/src/examples/xor/gkr_lookups/mod.rs deleted file mode 100644 index 6ee603eb0..000000000 --- a/crates/prover/src/examples/xor/gkr_lookups/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod accumulation; -pub mod mle_eval; diff --git a/crates/prover/src/examples/xor/mod.rs b/crates/prover/src/examples/xor/mod.rs deleted file mode 100644 index 34e702a9b..000000000 --- a/crates/prover/src/examples/xor/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod gkr_lookups;