Skip to content

Commit

Permalink
Create blake component that uses GKR for lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Sep 24, 2024
1 parent a969689 commit 3df0cb9
Show file tree
Hide file tree
Showing 28 changed files with 1,791 additions and 83 deletions.
14 changes: 13 additions & 1 deletion crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::ops::Deref;
use itertools::Itertools;
use tracing::{span, Level};

use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use super::{EvalAtRow, EvalAtRowWithMle, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
Expand Down Expand Up @@ -57,6 +57,18 @@ impl TraceLocationAllocator {
}
}

/// A component defined solely in means of the constraints framework.
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
/// Note that the constraint framework only support components with columns of the same size.
pub trait FrameworkEvalWithMle {
fn log_size(&self) -> u32;

fn max_constraint_log_degree_bound(&self) -> u32;

fn evaluate<E: EvalAtRowWithMle>(&self, eval: E) -> E;
}

/// A component defined solely in means of the constraints framework.
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<const N: usize> LookupElements<N> {
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
where
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<Output = EF>,
{
zip_eq(values, self.alpha_powers).fold(EF::zero(), |acc, (&value, power)| {
acc + EF::from(power) * value
Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ pub trait EvalAtRow {
/// Combines 4 base field values into a single extension field value.
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;
}

trait EvalAtRowWithMle: EvalAtRow {
fn add_mle_coeff_col_eval(&mut self, eval: Self::EF);
}
53 changes: 53 additions & 0 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,56 @@ impl<'a> EvalAtRow for PointEvaluator<'a> {
SecureField::from_partial_evals(values)
}
}

// /// Evaluates expressions at a point out of domain.
// pub struct MleCoeffColEvalAccumulator<'a> {
// pub mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
// pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
// pub col_index: Vec<usize>,
// pub denom_inverse: SecureField,
// }
// impl<'a> MleCoeffColEvalAccumulator<'a> {
// pub fn new(
// mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
// evaluation_accumulator: &'a mut PointEvaluationAccumulator,
// denom_inverse: SecureField,
// ) -> Self {
// let col_index = vec![0; mask.len()];
// Self {
// mask,
// evaluation_accumulator,
// col_index,
// denom_inverse,
// }
// }
// }
// impl<'a> EvalAtRow for MleCoeffColEvalAccumulator<'a> {
// type F = SecureField;
// type EF = SecureField;

// fn next_interaction_mask<const N: usize>(
// &mut self,
// interaction: usize,
// _offsets: [isize; N],
// ) -> [Self::F; N] {
// let col_index = self.col_index[interaction];
// self.col_index[interaction] += 1;
// let mask = self.mask[interaction][col_index].clone();
// assert_eq!(mask.len(), N);
// mask.try_into().unwrap()
// }
// fn add_constraint<G>(&mut self, constraint: G)
// where
// Self::EF: Mul<G, Output = Self::EF>,
// {
// }
// fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
// SecureField::from_partial_evals(values)
// }
// }

// impl<'a> EvalAtRowWithMle for MleCoeffColEvalAccumulator<'a> {
// fn add_mle_coeff_col_eval(&mut self, eval: Self::EF) {
// todo!()
// }
// }
4 changes: 4 additions & 0 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ mod tests {
let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -354,6 +355,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -391,6 +393,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -427,6 +430,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/core/backend/simd/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -590,6 +591,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -629,6 +631,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -666,6 +669,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down
13 changes: 12 additions & 1 deletion crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;

use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
use super::gkr_verifier::{Gate, GkrArtifact, GkrBatchProof, GkrMask};
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
Expand Down Expand Up @@ -409,6 +409,16 @@ pub fn prove_batch<B: GkrOps>(
.collect_vec();
let n_layers = *n_layers_by_instance.iter().max().unwrap();

let gate_by_instance = input_layer_by_instance
.iter()
.map(|l| match l {
Layer::GrandProduct(_) => Gate::GrandProduct,
Layer::LogUpGeneric { .. }
| Layer::LogUpMultiplicities { .. }
| Layer::LogUpSingles { .. } => Gate::LogUp,
})
.collect();

// Evaluate all instance circuits and collect the layer values.
let mut layers_by_instance = input_layer_by_instance
.into_iter()
Expand Down Expand Up @@ -502,6 +512,7 @@ pub fn prove_batch<B: GkrOps>(

let artifact = GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance: n_layers_by_instance,
};
Expand Down
109 changes: 108 additions & 1 deletion crates/prover/src/core/lookups/gkr_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ pub fn partially_verify_batch(

Ok(GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(),
})
Expand All @@ -162,12 +163,114 @@ pub struct GkrBatchProof {
pub struct GkrArtifact {
/// Out-of-domain (OOD) point for evaluating columns in the input layer.
pub ood_point: Vec<SecureField>,
/// The gate of each instance.
pub gate_by_instance: Vec<Gate>,
/// The claimed evaluation at `ood_point` for each column in the input layer of each instance.
pub claims_to_verify_by_instance: Vec<Vec<SecureField>>,
/// The number of variables that interpolate the input layer of each instance.
pub n_variables_by_instance: Vec<usize>,
}

impl GkrArtifact {
pub fn ood_point(&self, instance_n_variables: usize) -> &[SecureField] {
&self.ood_point[self.ood_point.len() - instance_n_variables..]
}
}

pub struct LookupArtifactInstanceIter<'proof, 'artifact> {
instance: usize,
gkr_proof: &'proof GkrBatchProof,
gkr_artifact: &'artifact GkrArtifact,
}

impl<'proof, 'artifact> LookupArtifactInstanceIter<'proof, 'artifact> {
pub fn new(gkr_proof: &'proof GkrBatchProof, gkr_artifact: &'artifact GkrArtifact) -> Self {
Self {
instance: 0,
gkr_proof,
gkr_artifact,
}
}
}

impl<'proof, 'artifact> Iterator for LookupArtifactInstanceIter<'proof, 'artifact> {
type Item = LookupArtifactInstance;

fn next(&mut self) -> Option<LookupArtifactInstance> {
if self.instance >= self.gkr_proof.output_claims_by_instance.len() {
return None;
}

let instance = self.instance;
let input_n_variables = self.gkr_artifact.n_variables_by_instance[instance];
let eval_point = self.gkr_artifact.ood_point(input_n_variables).to_vec();
let output_claim = &*self.gkr_proof.output_claims_by_instance[instance];
let input_claims = &*self.gkr_artifact.claims_to_verify_by_instance[instance];
let gate = self.gkr_artifact.gate_by_instance[instance];

let res = Some(match gate {
Gate::LogUp => {
let [numerator, denominator] = output_claim.try_into().unwrap();
let claimed_sum = Fraction::new(numerator, denominator);
let [input_numerators_claim, input_denominators_claim] =
input_claims.try_into().unwrap();

LookupArtifactInstance::LogUp(LogUpArtifactInstance {
eval_point,
input_n_variables,
input_numerators_claim,
input_denominators_claim,
claimed_sum,
})
}
Gate::GrandProduct => {
let [claimed_product] = output_claim.try_into().unwrap();
let [input_claim] = input_claims.try_into().unwrap();

LookupArtifactInstance::GrandProduct(GrandProductArtifactInstance {
eval_point,
input_n_variables,
input_claim,
claimed_product,
})
}
});

self.instance += 1;
res
}
}

// TODO: Consider making the GKR artifact just a Vec<LookupArtifactInstance>.
pub enum LookupArtifactInstance {
GrandProduct(GrandProductArtifactInstance),
LogUp(LogUpArtifactInstance),
}

pub struct GrandProductArtifactInstance {
/// GKR input layer eval point.
pub eval_point: Vec<SecureField>,
/// Number of variables the MLE in the GKR input layer had.
pub input_n_variables: usize,
/// Claimed input MLE evaluation at `eval_point`.
pub input_claim: SecureField,
/// Output claim from the circuit.
pub claimed_product: SecureField,
}

pub struct LogUpArtifactInstance {
/// GKR input layer eval point.
pub eval_point: Vec<SecureField>,
/// Number of variables the MLEs in the GKR input layer had.
pub input_n_variables: usize,
/// Claimed input numerators MLE evaluation at `eval_point`.
pub input_numerators_claim: SecureField,
/// Claimed input denominators MLE evaluation at `eval_point`.
pub input_denominators_claim: SecureField,
/// Output claim from the circuit.
pub claimed_sum: Fraction<SecureField, SecureField>,
}

/// Defines how a circuit operates locally on two input rows to produce a single output row.
/// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure.
///
Expand All @@ -176,7 +279,7 @@ pub struct GkrArtifact {
/// circuit) GKR prover implementations.
///
/// [Thaler13]: https://eprint.iacr.org/2013/351.pdf
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Gate {
LogUp,
GrandProduct,
Expand Down Expand Up @@ -305,11 +408,13 @@ mod tests {

let GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance,
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;

assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]);
assert_eq!(gate_by_instance, [Gate::GrandProduct, Gate::GrandProduct]);
assert_eq!(proof.output_claims_by_instance.len(), 2);
assert_eq!(claims_to_verify_by_instance.len(), 2);
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
Expand Down Expand Up @@ -338,11 +443,13 @@ mod tests {

let GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance,
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;

assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]);
assert_eq!(gate_by_instance, [Gate::GrandProduct, Gate::GrandProduct]);
assert_eq!(proof.output_claims_by_instance.len(), 2);
assert_eq!(claims_to_verify_by_instance.len(), 2);
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
Expand Down
Loading

0 comments on commit 3df0cb9

Please sign in to comment.