diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 1d893d6ef..f0d6ca9be 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -34,7 +34,7 @@ pub trait EvalAtRow { + Debug + Zero + Neg - + AddAssign + + AddAssign + AddAssign + Add + Sub @@ -52,6 +52,7 @@ pub trait EvalAtRow { + Zero + From + Neg + + AddAssign + Add + Sub + Mul diff --git a/crates/prover/src/core/poly/circle/secure_poly.rs b/crates/prover/src/core/poly/circle/secure_poly.rs index 8482e2971..a503bd2c6 100644 --- a/crates/prover/src/core/poly/circle/secure_poly.rs +++ b/crates/prover/src/core/poly/circle/secure_poly.rs @@ -73,6 +73,13 @@ impl, EvalOrder> SecureEvaluation { _eval_order: PhantomData, } } + + pub fn into_coordinate_evals( + self, + ) -> [CircleEvaluation; SECURE_EXTENSION_DEGREE] { + let Self { domain, values, .. } = self; + values.columns.map(|c| CircleEvaluation::new(domain, c)) + } } impl, EvalOrder> Deref for SecureEvaluation { diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index ec8537595..5a5d60535 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -1,11 +1,24 @@ //! Multilinear extension (MLE) eval at point constraints. +// TODO(andrew): Remove in downstream PR. +#![allow(dead_code)] + use std::array; +use itertools::Itertools; use num_traits::{One, Zero}; use crate::constraint_framework::EvalAtRow; +use crate::core::backend::simd::SimdBackend; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::constraints::{coset_vanishing, point_vanishing}; +use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumnByCoords; +use crate::core::fields::{Field, FieldExpOps}; use crate::core::lookups::utils::eq; +use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; /// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point. /// @@ -13,14 +26,20 @@ use crate::core::lookups::utils::eq; /// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. pub fn eval_mle_eval_constraints( mle_interaction: usize, - selector_interaction: usize, + const_interaction: usize, eval: &mut E, mle_coeffs_col_eval: E::EF, mle_eval_point: MleEvalPoint, mle_claim_shift: SecureField, + carry_quotients_col_eval: E::EF, ) { - let eq_col_eval = - eval_eq_constraints(mle_interaction, selector_interaction, eval, mle_eval_point); + let eq_col_eval = eval_eq_constraints( + mle_interaction, + const_interaction, + eval, + mle_eval_point, + carry_quotients_col_eval, + ); let terms_col_eval = mle_coeffs_col_eval * eq_col_eval; eval_prefix_sum_constraints(mle_interaction, eval, terms_col_eval, mle_claim_shift) } @@ -34,7 +53,7 @@ pub struct MleEvalPoint { // Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`. eq_carry_quotients: [SecureField; N_VARIABLES], // Point `p`. - _p: [SecureField; N_VARIABLES], + p: [SecureField; N_VARIABLES], } impl MleEvalPoint { @@ -53,7 +72,7 @@ impl MleEvalPoint { denom_assignment[i] = one; eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) }), - _p: p, + p, } } } @@ -68,12 +87,13 @@ impl MleEvalPoint { /// See (Section 5.1). fn eval_eq_constraints( eq_interaction: usize, - selector_interaction: usize, + const_interaction: usize, eval: &mut E, mle_eval_point: MleEvalPoint, + carry_quotients_col_eval: E::EF, ) -> E::EF { let [curr, next_next] = eval.next_extension_interaction_mask(eq_interaction, [0, 2]); - let [is_first, is_second] = eval.next_interaction_mask(selector_interaction, [0, -1]); + let [is_first, is_second] = eval.next_interaction_mask(const_interaction, [0, -1]); // Check the initial value on half_coset0 and final value on half_coset1. // Combining these constraints is safe because `is_first` and `is_second` are never @@ -82,20 +102,8 @@ fn eval_eq_constraints( let half_coset1_final_check = (curr - mle_eval_point.eq_1_p) * is_second; eval.add_constraint(half_coset0_initial_check + half_coset1_final_check); - // Check all variables except the last (last variable is handled by the constraint above). - #[allow(clippy::needless_range_loop)] - for variable_i in 0..N_VARIABLES.saturating_sub(1) { - let half_coset0_next = next_next; - let half_coset1_prev = next_next; - let [half_coset0_step, half_coset1_step] = - eval.next_interaction_mask(selector_interaction, [0, -1]); - let carry_quotient = mle_eval_point.eq_carry_quotients[variable_i]; - // Safe to combine these constraints as `is_step.half_coset0` and `is_step.half_coset1` - // are never non-zero at the same time on the trace. - let half_coset0_check = (curr - half_coset0_next * carry_quotient) * half_coset0_step; - let half_coset1_check = (curr * carry_quotient - half_coset1_prev) * half_coset1_step; - eval.add_constraint(half_coset0_check + half_coset1_check); - } + // Check all the steps. + eval.add_constraint(curr - next_next * carry_quotients_col_eval); curr } @@ -114,6 +122,137 @@ fn eval_prefix_sum_constraints( eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift); } +/// Returns succinct Eq carry quotients column. +/// +/// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point +/// `(r0, r1, ...)` let `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. This function +/// returns column `q(P)` such that all `c(C[i]) = c(C[i + 1]) * q(C[i])` and +/// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`. +/// +/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain +fn gen_carry_quotient_col( + eval_point: &MleEvalPoint, +) -> SecureEvaluation { + let (half_coset0_carry_quotients, half_coset1_carry_quotients) = + gen_half_coset_carry_quotients(eval_point); + + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let half_coset_size = size / 2; + let mut col = SecureColumnByCoords::::zeros(size); + + // TODO(andrew): Optimize. + for i in 0..half_coset_size { + let half_coset0_index = coset_index_to_circle_domain_index(i * 2, log_size); + let half_coset1_index = coset_index_to_circle_domain_index(i * 2 + 1, log_size); + let half_coset0_index_bit_rev = bit_reverse_index(half_coset0_index, log_size); + let half_coset1_index_bit_rev = bit_reverse_index(half_coset1_index, log_size); + + let n_trailing_ones = i.trailing_ones() as usize; + let half_coset0_carry_quotient = half_coset0_carry_quotients[n_trailing_ones]; + let half_coset1_carry_quotient = half_coset1_carry_quotients[n_trailing_ones]; + + col.set(half_coset0_index_bit_rev, half_coset0_carry_quotient); + col.set(half_coset1_index_bit_rev, half_coset1_carry_quotient); + } + + let domain = CanonicCoset::new(log_size).circle_domain(); + SecureEvaluation::new(domain, col) +} + +/// Evaluates the succinct Eq carry quotients column at point `p`. +/// +/// See [`gen_carry_quotient_col`]. +// TODO(andrew): Optimize further. Inline `eval_step_selector` and get runtime down to +// O(N_VARIABLES) vs current O(N_VARIABLES^2). Can also use vanishing evals to compute +// half_coset0_last half_coset1_first. +fn eval_carry_quotient_col( + eval_point: &MleEvalPoint, + p: CirclePoint, +) -> SecureField { + let log_size = N_VARIABLES as u32; + let coset = CanonicCoset::new(log_size).coset(); + + let (half_coset0_carry_quotients, half_coset1_carry_quotients) = + gen_half_coset_carry_quotients(eval_point); + + let mut eval = SecureField::zero(); + + for variable_i in 0..N_VARIABLES.saturating_sub(1) { + let log_step = variable_i as u32 + 2; + let offset = (1 << (log_step - 1)) - 2; + let half_coset0_selector = eval_step_selector_with_offset(coset, offset, log_step, p); + let half_coset1_selector = eval_step_selector_with_offset(coset, offset + 1, log_step, p); + let half_coset0_carry_quotient = half_coset0_carry_quotients[variable_i]; + let half_coset1_carry_quotient = half_coset1_carry_quotients[variable_i]; + eval += half_coset0_selector * half_coset0_carry_quotient; + eval += half_coset1_selector * half_coset1_carry_quotient; + } + + let half_coset0_last = eval_is_first(coset, p + coset.step.double().into_ef()); + let half_coset1_first = eval_is_first(coset, p + coset.step.into_ef()); + eval += *half_coset0_carry_quotients.last().unwrap() * half_coset0_last; + eval += *half_coset1_carry_quotients.last().unwrap() * half_coset1_first; + + eval +} + +/// Evaluates a polynomial that's `1` every `2^log_step` coset points, shifted by an offset, and `0` +/// elsewhere on coset. +fn eval_step_selector_with_offset( + coset: Coset, + offset: usize, + log_step: u32, + p: CirclePoint, +) -> SecureField { + let offset_step = coset.step.mul(offset as u128); + eval_step_selector(coset, log_step, p - offset_step.into_ef()) +} + +/// Evaluates a polynomial that's `1` every `2^log_step` coset points and `0` elsewhere on coset. +fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) -> SecureField { + if log_step == 0 { + return SecureField::one(); + } + + // Rotate the coset to have points on the `x` axis. + let p = p - coset.initial.into_ef(); + let mut vanish_at_log_step = (0..coset.log_size) + .scan(p, |p, _| { + let res = *p; + *p = p.double(); + Some(res.y) + }) + .collect_vec(); + vanish_at_log_step.reverse(); + // We only need the first `log_step` many values. + vanish_at_log_step.truncate(log_step as usize); + let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()]; + SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv); + + let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square(); + let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::(); + (half_coset_selector_dbl + vanish_at_log_step[0] * vanish_substep_inv_sum.double()) + / BaseField::from(1 << (log_step + 1)) +} + +fn eval_is_first(coset: Coset, p: CirclePoint) -> SecureField { + coset_vanishing(coset, p) + / (point_vanishing(coset.initial, p) * BaseField::from(1 << coset.log_size)) +} + +/// Output of the form: `(half_coset0_carry_quotients, half_coset1_carry_quotients)`. +fn gen_half_coset_carry_quotients( + eval_point: &MleEvalPoint, +) -> ([SecureField; N_VARIABLES], [SecureField; N_VARIABLES]) { + let last_variable = *eval_point.p.last().unwrap(); + let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients; + *half_coset0_carry_quotients.last_mut().unwrap() *= + eq(&[SecureField::one()], &[last_variable]) / eq(&[SecureField::zero()], &[last_variable]); + let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse()); + (half_coset0_carry_quotients, half_coset1_carry_quotients) +} + #[cfg(test)] mod tests { use std::array; @@ -125,7 +264,8 @@ mod tests { use rand::{Rng, SeedableRng}; use super::{ - eval_eq_constraints, eval_mle_eval_constraints, eval_prefix_sum_constraints, MleEvalPoint, + eval_carry_quotient_col, eval_eq_constraints, eval_mle_eval_constraints, + eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, }; use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; use crate::constraint_framework::{assert_constraints, EvalAtRow}; @@ -134,22 +274,24 @@ mod tests { use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; + use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::lookups::gkr_prover::GkrOps; use crate::core::lookups::mle::Mle; use crate::core::pcs::TreeVec; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; - - const EVAL_TRACE: usize = 0; - const CONST_TRACE: usize = 1; + use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; #[test] fn test_mle_eval_constraints_with_log_size_5() { const N_VARIABLES: usize = 5; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; let mut rng = SmallRng::seed_from_u64(0); let log_size = N_VARIABLES as u32; let size = 1 << log_size; @@ -159,13 +301,18 @@ mod tests { let base_trace = gen_base_trace(&mle, &eval_point); let claim = mle.eval_at_point(&eval_point); let claim_shift = claim / BaseField::from(size); - let constants_trace = gen_constants_trace(N_VARIABLES); - let traces = TreeVec::new(vec![base_trace, constants_trace]); + let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(log_size); assert_constraints(&trace_polys, trace_domain, |mut eval| { let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); eval_mle_eval_constraints( EVAL_TRACE, CONST_TRACE, @@ -173,6 +320,7 @@ mod tests { mle_coeff_col_eval, mle_eval_point, claim_shift, + carry_quotients_col_eval, ) }); } @@ -181,57 +329,99 @@ mod tests { #[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"] fn eq_constraints_with_4_variables() { const N_VARIABLES: usize = 4; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); let base_trace = gen_base_trace(&mle, &eval_point); - let constants_trace = gen_constants_trace(N_VARIABLES); - let traces = TreeVec::new(vec![base_trace, constants_trace]); + let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); - let mle_eval_point = MleEvalPoint::new(eval_point); assert_constraints(&trace_polys, trace_domain, |mut eval| { let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - eval_eq_constraints(EVAL_TRACE, CONST_TRACE, &mut eval, mle_eval_point); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + eval_eq_constraints( + EVAL_TRACE, + CONST_TRACE, + &mut eval, + mle_eval_point, + carry_quotients_col_eval, + ); }); } #[test] fn eq_constraints_with_5_variables() { const N_VARIABLES: usize = 5; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); let base_trace = gen_base_trace(&mle, &eval_point); - let constants_trace = gen_constants_trace(N_VARIABLES); - let traces = TreeVec::new(vec![base_trace, constants_trace]); + let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); - let mle_eval_point = MleEvalPoint::new(eval_point); assert_constraints(&trace_polys, trace_domain, |mut eval| { let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - eval_eq_constraints(EVAL_TRACE, CONST_TRACE, &mut eval, mle_eval_point); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + eval_eq_constraints( + EVAL_TRACE, + CONST_TRACE, + &mut eval, + mle_eval_point, + carry_quotients_col_eval, + ); }); } #[test] fn eq_constraints_with_8_variables() { const N_VARIABLES: usize = 8; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); let base_trace = gen_base_trace(&mle, &eval_point); - let constants_trace = gen_constants_trace(N_VARIABLES); - let traces = TreeVec::new(vec![base_trace, constants_trace]); + let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); - let mle_eval_point = MleEvalPoint::new(eval_point); assert_constraints(&trace_polys, trace_domain, |mut eval| { let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - eval_eq_constraints(EVAL_TRACE, CONST_TRACE, &mut eval, mle_eval_point); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + eval_eq_constraints( + EVAL_TRACE, + CONST_TRACE, + &mut eval, + mle_eval_point, + carry_quotients_col_eval, + ); }); } @@ -248,17 +438,47 @@ mod tests { assert_constraints(&trace_polys, trace_domain, |mut eval| { let [row_diff] = eval.next_extension_interaction_mask(0, [0]); - eval_prefix_sum_constraints(EVAL_TRACE, &mut eval, row_diff, cumulative_sum_shift) + eval_prefix_sum_constraints(0, &mut eval, row_diff, cumulative_sum_shift) }); } + #[test] + fn eval_step_selector_with_offset_works() { + const LOG_SIZE: u32 = 5; + const OFFSET: usize = 1; + const LOG_STEP: u32 = 2; + let coset = CanonicCoset::new(LOG_SIZE).coset(); + let col_eval = gen_is_step_with_offset::(LOG_SIZE, LOG_STEP, OFFSET); + let col_poly = col_eval.interpolate(); + let p = SECURE_FIELD_CIRCLE_GEN; + + let eval = eval_step_selector_with_offset(coset, OFFSET, LOG_STEP, p); + + assert_eq!(eval, col_poly.eval_at_point(p)); + } + + #[test] + fn eval_carry_quotient_col_works() { + const N_VARIABLES: usize = 5; + let mut rng = SmallRng::seed_from_u64(0); + let mle_eval_point = MleEvalPoint::::new(array::from_fn(|_| rng.gen())); + let col_eval = gen_carry_quotient_col(&mle_eval_point); + let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset); + let col_poly = col_eval.interpolate_with_twiddles(&twiddles); + let p = SECURE_FIELD_CIRCLE_GEN; + + let eval = eval_carry_quotient_col(&mle_eval_point, p); + + assert_eq!(eval, col_poly.eval_at_point(p)); + } + /// Generates a trace. /// /// Trace structure: /// /// ```text /// ------------------------------------------------------------------------------------- - /// | MLE coeffs | eq evals (basis) | MLE terms (prefix sum) | + /// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) | /// ------------------------------------------------------------------------------------- /// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 | /// ------------------------------------------------------------------------------------- @@ -343,25 +563,9 @@ mod tests { } } - fn gen_constants_trace( - n_variables: usize, + fn gen_constants_trace( ) -> Vec> { - let log_size = n_variables as u32; - let mut constants_trace = Vec::new(); - constants_trace.push(gen_is_first(log_size)); - - // TODO(andrew): Note the last selector column is not needed. The column for `is_first` - // with an offset for each half coset midpoint can be used instead. - for variable_i in 1..n_variables as u32 { - let half_coset_log_step = variable_i; - let half_coset_offset = (1 << (half_coset_log_step - 1)) - 1; - - let log_step = half_coset_log_step + 1; - let offset = half_coset_offset * 2; - - constants_trace.push(gen_is_step_with_offset(log_size, log_step, offset)) - } - - constants_trace + let log_size = N_VARIABLES as u32; + vec![gen_is_first(log_size)] } }