From a530f5ab446d815a991a8f7901dfccb46abf39a7 Mon Sep 17 00:00:00 2001 From: einar-taiko <126954546+einar-taiko@users.noreply.github.com> Date: Mon, 26 Jun 2023 23:24:24 +0900 Subject: [PATCH 1/6] Resolve Prover optimization: memory reduction #77 (#6) * Resolve taikoxyz/zkevm-circuits#77 * Please Clippy --- halo2_proofs/Cargo.toml | 5 +- halo2_proofs/benches/fft.rs | 18 +- halo2_proofs/examples/shuffle_api.rs | 207 ---- halo2_proofs/src/arithmetic.rs | 138 +-- halo2_proofs/src/dev.rs | 122 -- halo2_proofs/src/dev/failure.rs | 216 ---- halo2_proofs/src/fft.rs | 119 ++ halo2_proofs/src/fft/baseline.rs | 128 +++ halo2_proofs/src/fft/parallel.rs | 282 +++++ halo2_proofs/src/fft/recursive.rs | 461 ++++++++ halo2_proofs/src/lib.rs | 1 + halo2_proofs/src/plonk.rs | 23 +- halo2_proofs/src/plonk/circuit.rs | 124 ++- halo2_proofs/src/plonk/evaluation.rs | 1052 ++++++++++++------ halo2_proofs/src/plonk/keygen.rs | 16 +- halo2_proofs/src/plonk/permutation.rs | 16 +- halo2_proofs/src/plonk/permutation/keygen.rs | 12 - halo2_proofs/src/plonk/permutation/prover.rs | 6 +- halo2_proofs/src/plonk/prover.rs | 54 +- halo2_proofs/src/plonk/verifier.rs | 162 +-- halo2_proofs/src/poly/domain.rs | 783 ++++++------- 21 files changed, 2240 insertions(+), 1705 deletions(-) delete mode 100644 halo2_proofs/examples/shuffle_api.rs create mode 100644 halo2_proofs/src/fft.rs create mode 100644 halo2_proofs/src/fft/baseline.rs create mode 100644 halo2_proofs/src/fft/parallel.rs create mode 100644 halo2_proofs/src/fft/recursive.rs diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index 276dfb2563..0118b6b361 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -49,6 +49,7 @@ name = "fft" harness = false [dependencies] +itertools = "0.10" backtrace = { version = "0.3", optional = true } crossbeam = "0.8" ff = "0.13" @@ -78,6 +79,7 @@ gumdrop = "0.8" proptest = "1" rand_core = { version = "0.6", features = ["getrandom"] } rand_chacha = "0.3.1" +ark-std = { version = "0.3.0", features = ["print-trace"] } [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } @@ -108,6 +110,3 @@ name = "serialization" [[example]] name = "shuffle" - -[[example]] -name = "shuffle_api" diff --git a/halo2_proofs/benches/fft.rs b/halo2_proofs/benches/fft.rs index 0de72a0380..944425a380 100644 --- a/halo2_proofs/benches/fft.rs +++ b/halo2_proofs/benches/fft.rs @@ -3,20 +3,28 @@ extern crate criterion; use crate::arithmetic::best_fft; use group::ff::Field; -use halo2_proofs::*; -use halo2curves::pasta::Fp; +use halo2_proofs::{poly::EvaluationDomain, *}; +use halo2curves::bn256::Fr as Scalar; use criterion::{BenchmarkId, Criterion}; use rand_core::OsRng; fn criterion_benchmark(c: &mut Criterion) { + let j = 5; let mut group = c.benchmark_group("fft"); for k in 3..19 { + let domain = EvaluationDomain::new(j, k); + let omega = domain.get_omega(); + let l = 1 << k; + let data = domain.get_fft_data(l); + group.bench_function(BenchmarkId::new("k", k), |b| { - let mut a = (0..(1 << k)).map(|_| Fp::random(OsRng)).collect::>(); - let omega = Fp::random(OsRng); // would be weird if this mattered + let mut a = (0..(1 << k)) + .map(|_| Scalar::random(OsRng)) + .collect::>(); + b.iter(|| { - best_fft(&mut a, omega, k as u32); + best_fft(&mut a, omega, k, data, false); }); }); } diff --git a/halo2_proofs/examples/shuffle_api.rs b/halo2_proofs/examples/shuffle_api.rs deleted file mode 100644 index d3de4c2cdb..0000000000 --- 
a/halo2_proofs/examples/shuffle_api.rs +++ /dev/null @@ -1,207 +0,0 @@ -use std::{hash::Hash, marker::PhantomData, vec}; - -use ff::{FromUniformBytes, WithSmallOrderMulGroup}; -use halo2_proofs::{ - arithmetic::Field, - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{ - create_proof, keygen_pk, keygen_vk, verify_proof, Advice, Circuit, Column, - ConstraintSystem, Error, Fixed, Selector, - }, - poly::{ - commitment::ParamsProver, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - Rotation, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; -use halo2curves::{bn256::Bn256, serde::SerdeObject, CurveAffine}; -use pairing::MultiMillerLoop; -use rand_core::OsRng; - -struct ShuffleChip { - config: ShuffleConfig, - _marker: PhantomData, -} - -#[derive(Clone, Debug)] -struct ShuffleConfig { - input_0: Column, - input_1: Column, - shuffle_0: Column, - shuffle_1: Column, - s_input: Selector, - s_shuffle: Selector, -} - -impl ShuffleChip { - fn construct(config: ShuffleConfig) -> Self { - Self { - config, - _marker: PhantomData, - } - } - - fn configure( - meta: &mut ConstraintSystem, - input_0: Column, - input_1: Column, - shuffle_0: Column, - shuffle_1: Column, - ) -> ShuffleConfig { - let s_shuffle = meta.complex_selector(); - let s_input = meta.complex_selector(); - meta.shuffle("shuffle", |meta| { - let s_input = meta.query_selector(s_input); - let s_shuffle = meta.query_selector(s_shuffle); - let input_0 = meta.query_advice(input_0, Rotation::cur()); - let input_1 = meta.query_fixed(input_1, Rotation::cur()); - let shuffle_0 = meta.query_advice(shuffle_0, Rotation::cur()); - let shuffle_1 = meta.query_advice(shuffle_1, Rotation::cur()); - vec![ - (s_input.clone() * input_0, s_shuffle.clone() * shuffle_0), - (s_input * input_1, s_shuffle * shuffle_1), - ] - }); - ShuffleConfig { - input_0, - input_1, - shuffle_0, - shuffle_1, - s_input, - s_shuffle, - } - } -} - -#[derive(Default)] -struct MyCircuit { - input_0: Vec>, - input_1: Vec, - shuffle_0: Vec>, - shuffle_1: Vec>, -} - -impl Circuit for MyCircuit { - // Since we are using a single chip for everything, we can just reuse its config. 
- type Config = ShuffleConfig; - type FloorPlanner = SimpleFloorPlanner; - #[cfg(feature = "circuit-params")] - fn params(&self) -> Self::Params {} - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let input_0 = meta.advice_column(); - let input_1 = meta.fixed_column(); - let shuffle_0 = meta.advice_column(); - let shuffle_1 = meta.advice_column(); - ShuffleChip::configure(meta, input_0, input_1, shuffle_0, shuffle_1) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let ch = ShuffleChip::::construct(config); - layouter.assign_region( - || "load inputs & shuffles", - |mut region| { - for (i, (input_0, input_1)) in - self.input_0.iter().zip(self.input_1.iter()).enumerate() - { - region.assign_advice(ch.config.input_0, i, *input_0); - region.assign_fixed(ch.config.input_1, i, *input_1); - ch.config.s_input.enable(&mut region, i)?; - } - - for (i, (shuffle_0, shuffle_1)) in - self.shuffle_0.iter().zip(self.shuffle_1.iter()).enumerate() - { - region.assign_advice(ch.config.shuffle_0, i, *shuffle_0); - region.assign_advice(ch.config.shuffle_1, i, *shuffle_1); - ch.config.s_shuffle.enable(&mut region, i)?; - } - Ok(()) - }, - )?; - Ok(()) - } -} - -fn test_prover(k: u32, circuit: MyCircuit, expected: bool) -where - E::Fr: Hash + FromUniformBytes<64> + WithSmallOrderMulGroup<3>, - E::G1Affine: CurveAffine + SerdeObject, - E::G2Affine: CurveAffine + SerdeObject, -{ - let params = ParamsKZG::::new(k); - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); - - let proof = { - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - - create_proof::, ProverSHPLONK, _, _, _, _>( - ¶ms, - &pk, - &[circuit], - &[&[]], - OsRng, - &mut transcript, - ) - .expect("proof generation should not fail"); - - transcript.finalize() - }; - - let accepted = { - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - - verify_proof::, VerifierSHPLONK, _, _, _>( - ¶ms, - pk.get_vk(), - strategy, - &[&[]], - &mut transcript, - ) - .is_ok() - }; - - assert_eq!(accepted, expected); -} - -fn main() { - use halo2_proofs::dev::MockProver; - use halo2curves::bn256::Fr; - const K: u32 = 4; - let input_0 = [1, 2, 4, 1] - .map(|e: u64| Value::known(Fr::from(e))) - .to_vec(); - let input_1 = [10, 20, 40, 10].map(Fr::from).to_vec(); - let shuffle_0 = [4, 1, 1, 2] - .map(|e: u64| Value::known(Fr::from(e))) - .to_vec(); - let shuffle_1 = [40, 10, 10, 20] - .map(|e: u64| Value::known(Fr::from(e))) - .to_vec(); - let circuit = MyCircuit { - input_0, - input_1, - shuffle_0, - shuffle_1, - }; - let prover = MockProver::run(K, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - test_prover::(K, circuit, true); -} diff --git a/halo2_proofs/src/arithmetic.rs b/halo2_proofs/src/arithmetic.rs index 3a67afbcb2..6c026bbcba 100644 --- a/halo2_proofs/src/arithmetic.rs +++ b/halo2_proofs/src/arithmetic.rs @@ -214,132 +214,31 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu res } -/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size -/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative -/// order $n$ called `omega` ($\omega$). 
The result is that the vector `a`, when -/// interpreted as the coefficients of a polynomial of degree $n - 1$, is -/// transformed into the evaluations of this polynomial at each of the $n$ -/// distinct powers of $\omega$. This transformation is invertible by providing -/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element -/// by $n$. -/// -/// This will use multithreading if beneficial. -pub fn best_fft>(a: &mut [G], omega: Scalar, log_n: u32) { - fn bitreverse(mut n: usize, l: usize) -> usize { - let mut r = 0; - for _ in 0..l { - r = (r << 1) | (n & 1); - n >>= 1; - } - r - } - - let threads = multicore::current_num_threads(); - let log_threads = log2_floor(threads); - let n = a.len(); - assert_eq!(n, 1 << log_n); - - for k in 0..n { - let rk = bitreverse(k, log_n as usize); - if k < rk { - a.swap(rk, k); - } - } - - //let start = start_measure(format!("twiddles {} ({})", a.len(), threads), false); - // precompute twiddle factors - let twiddles: Vec<_> = (0..(n / 2)) - .scan(Scalar::ONE, |w, _| { - let tw = *w; - *w *= ω - Some(tw) - }) - .collect(); - - if log_n <= log_threads { - let mut chunk = 2_usize; - let mut twiddle_chunk = n / 2; - for _ in 0..log_n { - a.chunks_mut(chunk).for_each(|coeffs| { - let (left, right) = coeffs.split_at_mut(chunk / 2); - - // case when twiddle factor is one - let (a, left) = left.split_at_mut(1); - let (b, right) = right.split_at_mut(1); - let t = b[0]; - b[0] = a[0]; - a[0] += &t; - b[0] -= &t; - - left.iter_mut() - .zip(right.iter_mut()) - .enumerate() - .for_each(|(i, (a, b))| { - let mut t = *b; - t *= &twiddles[(i + 1) * twiddle_chunk]; - *b = *a; - *a += &t; - *b -= &t; - }); - }); - chunk *= 2; - twiddle_chunk /= 2; - } - } else { - recursive_butterfly_arithmetic(a, n, 1, &twiddles) - } -} - -/// This perform recursive butterfly arithmetic -pub fn recursive_butterfly_arithmetic>( +/// Dispatcher +pub fn best_fft>( a: &mut [G], - n: usize, - twiddle_chunk: usize, - twiddles: &[Scalar], + omega: Scalar, + log_n: u32, + data: &FFTData, + inverse: bool, ) { - if n == 2 { - let t = a[1]; - a[1] = a[0]; - a[0] += &t; - a[1] -= &t; - } else { - let (left, right) = a.split_at_mut(n / 2); - multicore::join( - || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), - || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), - ); - - // case when twiddle factor is one - let (a, left) = left.split_at_mut(1); - let (b, right) = right.split_at_mut(1); - let t = b[0]; - b[0] = a[0]; - a[0] += &t; - b[0] -= &t; - - left.iter_mut() - .zip(right.iter_mut()) - .enumerate() - .for_each(|(i, (a, b))| { - let mut t = *b; - t *= &twiddles[(i + 1) * twiddle_chunk]; - *b = *a; - *a += &t; - *b -= &t; - }); - } + fft::fft(a, omega, log_n, data, inverse); } /// Convert coefficient bases group elements to lagrange basis by inverse FFT. 
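+// The dispatcher above now threads a precomputed `FFTData` through every call.
+// A minimal sketch of the new call sequence, mirroring the updated fft bench
+// (`EvaluationDomain::new`, `get_omega` and `get_fft_data` are the entry
+// points used there; `coeffs` and `k` are illustrative placeholders):
+//
+//     use halo2_proofs::{arithmetic::best_fft, poly::EvaluationDomain};
+//     use halo2curves::bn256::Fr as Scalar;
+//
+//     let domain = EvaluationDomain::<Scalar>::new(1, k);
+//     let data = domain.get_fft_data(1 << k);
+//     let mut a: Vec<Scalar> = coeffs; // 2^k polynomial coefficients
+//     best_fft(&mut a, domain.get_omega(), k, data, false);
+//
+// Building the twiddle tables once per domain size and reusing them across
+// calls is what the extra parameters buy.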
pub fn g_to_lagrange(g_projective: Vec, k: u32) -> Vec { let n_inv = C::Scalar::TWO_INV.pow_vartime([k as u64, 0, 0, 0]); + let omega = C::Scalar::ROOT_OF_UNITY; let mut omega_inv = C::Scalar::ROOT_OF_UNITY_INV; for _ in k..C::Scalar::S { omega_inv = omega_inv.square(); } let mut g_lagrange_projective = g_projective; - best_fft(&mut g_lagrange_projective, omega_inv, k); + let n = g_lagrange_projective.len(); + let fft_data = FFTData::new(n, omega, omega_inv); + + best_fft(&mut g_lagrange_projective, omega_inv, k, &fft_data, false); parallelize(&mut g_lagrange_projective, |g, _| { for g in g.iter_mut() { *g *= n_inv; @@ -498,7 +397,7 @@ pub fn parallelize_count( }); } -fn log2_floor(num: usize) -> u32 { +pub fn log2_floor(num: usize) -> u32 { assert!(num > 0); let mut pow = 0; @@ -592,9 +491,20 @@ pub(crate) fn powers(base: F) -> impl Iterator { std::iter::successors(Some(F::ONE), move |power| Some(base * power)) } +/// Reverse `l` LSBs of bitvector `n` +pub fn bitreverse(mut n: usize, l: usize) -> usize { + let mut r = 0; + for _ in 0..l { + r = (r << 1) | (n & 1); + n >>= 1; + } + r +} + #[cfg(test)] use rand_core::OsRng; +use crate::fft::{self, recursive::FFTData}; #[cfg(test)] use crate::halo2curves::pasta::Fp; // use crate::plonk::{get_duration, get_time, start_measure, stop_measure}; diff --git a/halo2_proofs/src/dev.rs b/halo2_proofs/src/dev.rs index 05d2974b54..23cf45ba1a 100644 --- a/halo2_proofs/src/dev.rs +++ b/halo2_proofs/src/dev.rs @@ -1057,66 +1057,6 @@ impl + Ord> MockProver { .collect::>() }); - let shuffle_errors = - self.cs - .shuffles - .iter() - .enumerate() - .flat_map(|(shuffle_index, shuffle)| { - assert!(shuffle.shuffle_expressions.len() == shuffle.input_expressions.len()); - assert!(self.usable_rows.end > 0); - - let mut shuffle_rows: Vec>> = self - .usable_rows - .clone() - .map(|row| { - let t = shuffle - .shuffle_expressions - .iter() - .map(move |c| load(c, row)) - .collect(); - t - }) - .collect(); - shuffle_rows.sort(); - - let mut input_rows: Vec<(Vec>, usize)> = self - .usable_rows - .clone() - .map(|input_row| { - let t = shuffle - .input_expressions - .iter() - .map(move |c| load(c, input_row)) - .collect(); - - (t, input_row) - }) - .collect(); - input_rows.sort(); - - input_rows - .iter() - .zip(shuffle_rows.iter()) - .filter_map(|((input_value, row), shuffle_value)| { - if shuffle_value != input_value { - Some(VerifyFailure::Shuffle { - name: shuffle.name.clone(), - shuffle_index, - location: FailureLocation::find_expressions( - &self.cs, - &self.regions, - *row, - shuffle.input_expressions.iter(), - ), - }) - } else { - None - } - }) - .collect::>() - }); - let mapping = self.permutation.mapping(); // Check that permutations preserve the original values of the cells. 
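// (Each assigned cell is compared against the value at the position its
// `mapping` image points to; any mismatch below surfaces as a
// `VerifyFailure::Permutation`.)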
let perm_errors = { @@ -1170,7 +1110,6 @@ impl + Ord> MockProver { .chain(gate_errors) .chain(lookup_errors) .chain(perm_errors) - .chain(shuffle_errors) .collect(); if errors.is_empty() { Ok(()) @@ -1521,66 +1460,6 @@ impl + Ord> MockProver { .collect::>() }); - let shuffle_errors = - self.cs - .shuffles - .iter() - .enumerate() - .flat_map(|(shuffle_index, shuffle)| { - assert!(shuffle.shuffle_expressions.len() == shuffle.input_expressions.len()); - assert!(self.usable_rows.end > 0); - - let mut shuffle_rows: Vec>> = self - .usable_rows - .clone() - .map(|row| { - let t = shuffle - .shuffle_expressions - .iter() - .map(move |c| load(c, row)) - .collect(); - t - }) - .collect(); - shuffle_rows.sort(); - - let mut input_rows: Vec<(Vec>, usize)> = self - .usable_rows - .clone() - .map(|input_row| { - let t = shuffle - .input_expressions - .iter() - .map(move |c| load(c, input_row)) - .collect(); - - (t, input_row) - }) - .collect(); - input_rows.sort(); - - input_rows - .iter() - .zip(shuffle_rows.iter()) - .filter_map(|((input_value, row), shuffle_value)| { - if shuffle_value != input_value { - Some(VerifyFailure::Shuffle { - name: shuffle.name.clone(), - shuffle_index, - location: FailureLocation::find_expressions( - &self.cs, - &self.regions, - *row, - shuffle.input_expressions.iter(), - ), - }) - } else { - None - } - }) - .collect::>() - }); - let mapping = self.permutation.mapping(); // Check that permutations preserve the original values of the cells. let perm_errors = { @@ -1634,7 +1513,6 @@ impl + Ord> MockProver { .chain(gate_errors) .chain(lookup_errors) .chain(perm_errors) - .chain(shuffle_errors) .collect(); if errors.is_empty() { Ok(()) diff --git a/halo2_proofs/src/dev/failure.rs b/halo2_proofs/src/dev/failure.rs index eb71531041..9f6ed44fa2 100644 --- a/halo2_proofs/src/dev/failure.rs +++ b/halo2_proofs/src/dev/failure.rs @@ -192,28 +192,6 @@ pub enum VerifyFailure { /// lookup is active on a row adjacent to an unrelated region. location: FailureLocation, }, - /// A shuffle input did not exist in its corresponding map. - Shuffle { - /// The name of the lookup that is not satisfied. - name: String, - /// The index of the lookup that is not satisfied. These indices are assigned in - /// the order in which `ConstraintSystem::lookup` is called during - /// `Circuit::configure`. - shuffle_index: usize, - /// The location at which the lookup is not satisfied. - /// - /// `FailureLocation::InRegion` is most common, and may be due to the intentional - /// use of a lookup (if its inputs are conditional on a complex selector), or an - /// unintentional lookup constraint that overlaps the region (indicating that the - /// lookup's inputs should be made conditional). - /// - /// `FailureLocation::OutsideRegion` is uncommon, and could mean that: - /// - The input expressions do not correctly constrain a default value that exists - /// in the table when the lookup is not being used. - /// - The input expressions use a column queried at a non-zero `Rotation`, and the - /// lookup is active on a row adjacent to an unrelated region. - location: FailureLocation, - }, /// A permutation did not preserve the original value of a cell. Permutation { /// The column in which this permutation is not satisfied. 
@@ -290,17 +268,6 @@ impl fmt::Display for VerifyFailure { name, lookup_index, location ) } - Self::Shuffle { - name, - shuffle_index, - location, - } => { - write!( - f, - "Shuffle {}(index: {}) is not satisfied {}", - name, shuffle_index, location - ) - } Self::Permutation { column, location } => { write!( f, @@ -684,184 +651,6 @@ fn render_lookup( } } -fn render_shuffle( - prover: &MockProver, - name: &str, - shuffle_index: usize, - location: &FailureLocation, -) { - let n = prover.n as i32; - let cs = &prover.cs; - let shuffle = &cs.shuffles[shuffle_index]; - - // Get the absolute row on which the shuffle's inputs are being queried, so we can - // fetch the input values. - let row = match location { - FailureLocation::InRegion { region, offset } => { - prover.regions[region.index].rows.unwrap().0 + offset - } - FailureLocation::OutsideRegion { row } => *row, - } as i32; - - let shuffle_columns = shuffle.shuffle_expressions.iter().map(|expr| { - expr.evaluate( - &|f| format! {"Const: {:#?}", f}, - &|s| format! {"S{}", s.0}, - &|query| { - format!( - "{:?}", - prover - .cs - .general_column_annotations - .get(&metadata::Column::from((Any::Fixed, query.column_index))) - .cloned() - .unwrap_or_else(|| format!("F{}", query.column_index())) - ) - }, - &|query| { - format!( - "{:?}", - prover - .cs - .general_column_annotations - .get(&metadata::Column::from((Any::advice(), query.column_index))) - .cloned() - .unwrap_or_else(|| format!("A{}", query.column_index())) - ) - }, - &|query| { - format!( - "{:?}", - prover - .cs - .general_column_annotations - .get(&metadata::Column::from((Any::Instance, query.column_index))) - .cloned() - .unwrap_or_else(|| format!("I{}", query.column_index())) - ) - }, - &|challenge| format! {"C{}", challenge.index()}, - &|query| format! {"-{}", query}, - &|a, b| format! {"{} + {}", a,b}, - &|a, b| format! {"{} * {}", a,b}, - &|a, b| format! {"{} * {:?}", a, b}, - ) - }); - - fn cell_value<'a, F: Field, Q: Into + Copy>( - load: impl Fn(Q) -> Value + 'a, - ) -> impl Fn(Q) -> BTreeMap + 'a { - move |query| { - let AnyQuery { - column_type, - column_index, - rotation, - .. - } = query.into(); - Some(( - ((column_type, column_index).into(), rotation.0).into(), - match load(query) { - Value::Real(v) => util::format_value(v), - Value::Poison => unreachable!(), - }, - )) - .into_iter() - .collect() - } - } - - eprintln!("error: input does not exist in shuffle"); - eprint!(" ("); - for i in 0..shuffle.input_expressions.len() { - eprint!("{}L{}", if i == 0 { "" } else { ", " }, i); - } - eprint!(") <-> ("); - for (i, column) in shuffle_columns.enumerate() { - eprint!("{}{}", if i == 0 { "" } else { ", " }, column); - } - eprintln!(")"); - - eprintln!(); - eprintln!(" Shuffle '{}' inputs:", name); - let advice = prover - .advice - .iter() - .map(|advice| { - advice - .iter() - .map(|rc| match rc { - AdviceCellValue::Assigned(a) => CellValue::Assigned(a.evaluate()), - AdviceCellValue::Poison(i) => CellValue::Poison(*i), - }) - .collect::>() - }) - .collect::>(); - for (i, input) in shuffle.input_expressions.iter().enumerate() { - // Fetch the cell values (since we don't store them in VerifyFailure::Shuffle). 
- let cell_values = input.evaluate( - &|_| BTreeMap::default(), - &|_| panic!("virtual selectors are removed during optimization"), - &cell_value(&util::load(n, row, &cs.fixed_queries, &prover.fixed)), - &cell_value(&util::load(n, row, &cs.advice_queries, &advice)), - &cell_value(&util::load_instance( - n, - row, - &cs.instance_queries, - &prover.instance, - )), - &|_| BTreeMap::default(), - &|a| a, - &|mut a, mut b| { - a.append(&mut b); - a - }, - &|mut a, mut b| { - a.append(&mut b); - a - }, - &|a, _| a, - ); - - // Collect the necessary rendering information: - // - The columns involved in this constraint. - // - How many cells are in each column. - // - The grid of cell values, indexed by rotation. - let mut columns = BTreeMap::::default(); - let mut layout = BTreeMap::>::default(); - for (i, (cell, _)) in cell_values.iter().enumerate() { - *columns.entry(cell.column).or_default() += 1; - layout - .entry(cell.rotation) - .or_default() - .entry(cell.column) - .or_insert(format!("x{}", i)); - } - - if i != 0 { - eprintln!(); - } - eprintln!( - " Sh{} = {}", - i, - emitter::expression_to_string(input, &layout) - ); - eprintln!(" ^"); - - emitter::render_cell_layout(" | ", location, &columns, &layout, |_, rotation| { - if rotation == 0 { - eprint!(" <--{{ Shuffle '{}' inputs queried here", name); - } - }); - - // Print the map from local variables to assigned values. - eprintln!(" |"); - eprintln!(" | Assigned cell values:"); - for (i, (_, value)) in cell_values.iter().enumerate() { - eprintln!(" | x{} = {}", i, value); - } - } -} - impl VerifyFailure { /// Emits this failure in pretty-printed format to stderr. pub(super) fn emit(&self, prover: &MockProver) { @@ -892,11 +681,6 @@ impl VerifyFailure { lookup_index, location, } => render_lookup(prover, name, *lookup_index, location), - Self::Shuffle { - name, - shuffle_index, - location, - } => render_shuffle(prover, name, *shuffle_index, location), _ => eprintln!("{}", self), } } diff --git a/halo2_proofs/src/fft.rs b/halo2_proofs/src/fft.rs new file mode 100644 index 0000000000..326fd9d318 --- /dev/null +++ b/halo2_proofs/src/fft.rs @@ -0,0 +1,119 @@ +//! This is a module for dispatching between different FFT implementations at runtime based on environment variable `FFT`. 
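+// As of this patch the dispatcher below forwards to `recursive::fft`
+// unconditionally. `baseline`, `parallel` and `recursive` all expose the same
+// `fft(a, omega, log_n, data, inverse)` adaptor, so a runtime switch on the
+// `FFT` variable named above could plausibly look like this (hypothetical
+// sketch, not what the code below currently does):
+//
+//     match std::env::var("FFT").as_deref() {
+//         Ok("baseline") => baseline::fft(a, omega, log_n, data, inverse),
+//         Ok("parallel") => parallel::fft(a, omega, log_n, data, inverse),
+//         _ => recursive::fft(a, omega, log_n, data, inverse),
+//     }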
+ +use ff::Field; + +use self::recursive::FFTData; +use crate::arithmetic::FftGroup; + +pub mod baseline; +pub mod parallel; +pub mod recursive; + +/// Runtime dispatcher to concrete FFT implementation +pub fn fft>( + a: &mut [G], + omega: Scalar, + log_n: u32, + data: &FFTData, + inverse: bool, +) { + recursive::fft(a, omega, log_n, data, inverse) +} + +#[cfg(test)] +mod tests { + use ark_std::{end_timer, start_timer}; + use ff::Field; + use halo2curves::bn256::Fr as Scalar; + use rand_core::OsRng; + + use crate::{arithmetic::best_fft, fft, multicore, poly::EvaluationDomain}; + + #[test] + fn test_fft_recursive() { + let k = 22; + + let domain = EvaluationDomain::::new(1, k); + let n = domain.get_n() as usize; + + let input = vec![Scalar::random(OsRng); n]; + + let num_threads = multicore::current_num_threads(); + + let mut a = input.clone(); + let l_a = a.len(); + let start = start_timer!(|| format!("best fft {} ({})", a.len(), num_threads)); + fft::baseline::fft( + &mut a, + domain.get_omega(), + k, + domain.get_fft_data(l_a), + false, + ); + end_timer!(start); + + let mut b = input; + let l_b = b.len(); + let start = start_timer!(|| format!("recursive fft {} ({})", a.len(), num_threads)); + fft::recursive::fft( + &mut b, + domain.get_omega(), + k, + domain.get_fft_data(l_b), + false, + ); + end_timer!(start); + + for i in 0..n { + //log_info(format!("{}: {} {}", i, a[i], b[i])); + assert_eq!(a[i], b[i]); + } + } + + #[test] + fn test_ifft_recursive() { + let k = 22; + + let domain = EvaluationDomain::::new(1, k); + let n = domain.get_n() as usize; + + let input = vec![Scalar::random(OsRng); n]; + + let mut a = input.clone(); + let l_a = a.len(); + fft::recursive::fft( + &mut a, + domain.get_omega(), + k, + domain.get_fft_data(l_a), + false, + ); + fft::recursive::fft( + &mut a, + domain.get_omega_inv(), // doesn't actually do anything + k, + domain.get_fft_data(l_a), + true, + ); + let ifft_divisor = Scalar::from(n as u64).invert().unwrap(); + + for i in 0..n { + assert_eq!(input[i], a[i] * ifft_divisor); + } + } + + #[test] + fn test_mem_leak() { + let j = 1; + let k = 3; + let domain = EvaluationDomain::new(j, k); + let omega = domain.get_omega(); + let l = 1 << k; + let data = domain.get_fft_data(l); + let mut a = (0..(1 << k)) + .map(|_| Scalar::random(OsRng)) + .collect::>(); + + best_fft(&mut a, omega, k, data, false); + } +} diff --git a/halo2_proofs/src/fft/baseline.rs b/halo2_proofs/src/fft/baseline.rs new file mode 100644 index 0000000000..8af165dc56 --- /dev/null +++ b/halo2_proofs/src/fft/baseline.rs @@ -0,0 +1,128 @@ +//! This contains the baseline FFT implementation + +use ff::Field; + +use super::recursive::FFTData; +use crate::{ + arithmetic::{self, log2_floor, FftGroup}, + multicore, +}; + +/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size +/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative +/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when +/// interpreted as the coefficients of a polynomial of degree $n - 1$, is +/// transformed into the evaluations of this polynomial at each of the $n$ +/// distinct powers of $\omega$. This transformation is invertible by providing +/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element +/// by $n$. +/// +/// This will use multithreading if beneficial. 
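+// (Unchanged radix-2 strategy, now private to this module: bit-reverse the
+// input, precompute the n/2 twiddles, then run the iterative breadth-first
+// passes, or the recursive splitter when `log_n` exceeds `log2(threads)` so
+// the two halves can be transformed on separate threads.)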
+fn best_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + let threads = multicore::current_num_threads(); + let log_threads = log2_floor(threads); + let n = a.len(); + assert_eq!(n, 1 << log_n); + + for k in 0..n { + let rk = arithmetic::bitreverse(k, log_n as usize); + if k < rk { + a.swap(rk, k); + } + } + + //let start = start_measure(format!("twiddles {} ({})", a.len(), threads), false); + // precompute twiddle factors + let twiddles: Vec<_> = (0..(n / 2)) + .scan(Scalar::ONE, |w, _| { + let tw = *w; + *w *= ω + Some(tw) + }) + .collect(); + //stop_measure(start); + + if log_n <= log_threads { + let mut chunk = 2_usize; + let mut twiddle_chunk = n / 2; + for _ in 0..log_n { + a.chunks_mut(chunk).for_each(|coeffs| { + let (left, right) = coeffs.split_at_mut(chunk / 2); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + }); + chunk *= 2; + twiddle_chunk /= 2; + } + } else { + recursive_butterfly_arithmetic(a, n, 1, &twiddles) + } +} + +/// This perform recursive butterfly arithmetic +fn recursive_butterfly_arithmetic>( + a: &mut [G], + n: usize, + twiddle_chunk: usize, + twiddles: &[Scalar], +) { + if n == 2 { + let t = a[1]; + a[1] = a[0]; + a[0] += &t; + a[1] -= &t; + } else { + let (left, right) = a.split_at_mut(n / 2); + multicore::join( + || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), + || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), + ); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + } +} + +/// Generic adaptor +pub fn fft>( + a: &mut [G], + omega: Scalar, + log_n: u32, + _data: &FFTData, + _inverse: bool, +) { + best_fft(a, omega, log_n) +} diff --git a/halo2_proofs/src/fft/parallel.rs b/halo2_proofs/src/fft/parallel.rs new file mode 100644 index 0000000000..014cb538d6 --- /dev/null +++ b/halo2_proofs/src/fft/parallel.rs @@ -0,0 +1,282 @@ +//! This module provides common utilities, traits and structures for group, +//! field and polynomial arithmetic. 
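+// Concretely, this module holds the split-radix parallel FFT. It runs in four
+// phases: per-thread split-radix passes into a scratch buffer, a
+// transpose-style shuffle, independent serial sub-FFTs of length
+// n / 2^log_split, and a final unshuffle. Twiddles come from a sparse
+// two-level table; omega^i is reconstructed as in `serial_split_fft`:
+//
+//     let low  = i % (1 << SPARSE_TWIDDLE_DEGREE);
+//     let high = i >> SPARSE_TWIDDLE_DEGREE;
+//     let mut w = twiddle_lut[low];
+//     if high > 0 { w *= twiddle_lut[(1 << SPARSE_TWIDDLE_DEGREE) + high]; }
+//
+// which trades one extra multiplication per lookup for a much smaller table.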
+ +use crate::arithmetic::{self, log2_floor, FftGroup}; + +use crate::multicore; +pub use ff::Field; +pub use halo2curves::{CurveAffine, CurveExt}; + +use super::recursive::FFTData; + +/// A constant +pub const SPARSE_TWIDDLE_DEGREE: u32 = 10; + +/// Dispatcher +fn best_fft_opt>(a: &mut [G], omega: Scalar, log_n: u32) { + let threads = multicore::current_num_threads(); + let log_split = log2_floor(threads) as usize; + let n = a.len(); + let sub_n = n >> log_split; + let split_m = 1 << log_split; + + if sub_n >= split_m { + parallel_fft(a, omega, log_n); + } else { + serial_fft(a, omega, log_n); + } +} + +fn serial_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + let n = a.len() as u32; + assert_eq!(n, 1 << log_n); + + for k in 0..n as usize { + let rk = arithmetic::bitreverse(k, log_n as usize); + if k < rk { + a.swap(rk, k); + } + } + + let mut m = 1; + for _ in 0..log_n { + let w_m: Scalar = omega.pow_vartime([u64::from(n / (2 * m)), 0, 0, 0]); + + let mut k = 0; + while k < n { + let mut w = Scalar::ONE; + for j in 0..m { + let mut t = a[(k + j + m) as usize]; + t *= &w; + a[(k + j + m) as usize] = a[(k + j) as usize]; + a[(k + j + m) as usize] -= &t; + a[(k + j) as usize] += &t; + w *= &w_m; + } + + k += 2 * m; + } + + m *= 2; + } +} + +fn serial_split_fft>( + a: &mut [G], + twiddle_lut: &[Scalar], + twiddle_scale: usize, + log_n: u32, +) { + let n = a.len() as u32; + assert_eq!(n, 1 << log_n); + + let mut m = 1; + for _ in 0..log_n { + let omega_idx = twiddle_scale * n as usize / (2 * m as usize); // 1/2, 1/4, 1/8, ... + let low_idx = omega_idx % (1 << SPARSE_TWIDDLE_DEGREE); + let high_idx = omega_idx >> SPARSE_TWIDDLE_DEGREE; + let mut w_m = twiddle_lut[low_idx]; + if high_idx > 0 { + w_m *= twiddle_lut[(1 << SPARSE_TWIDDLE_DEGREE) + high_idx]; + } + + let mut k = 0; + while k < n { + let mut w = Scalar::ONE; + for j in 0..m { + let mut t = a[(k + j + m) as usize]; + t *= &w; + a[(k + j + m) as usize] = a[(k + j) as usize]; + a[(k + j + m) as usize] -= &t; + a[(k + j) as usize] += &t; + w *= &w_m; + } + + k += 2 * m; + } + + m *= 2; + } +} + +fn split_radix_fft>( + tmp: &mut [G], + a: &[G], + twiddle_lut: &[Scalar], + n: usize, + sub_fft_offset: usize, + log_split: usize, +) { + let split_m = 1 << log_split; + let sub_n = n >> log_split; + + // we use out-place bitreverse here, split_m <= num_threads, so the buffer spase is small + // and it's is good for data locality + let tmp_filler_val = tmp[0]; + let mut t1 = vec![tmp_filler_val; split_m]; + for i in 0..split_m { + t1[arithmetic::bitreverse(i, log_split)] = a[i * sub_n + sub_fft_offset]; + } + serial_split_fft(&mut t1, twiddle_lut, sub_n, log_split as u32); + + let sparse_degree = SPARSE_TWIDDLE_DEGREE; + let omega_idx = sub_fft_offset; + let low_idx = omega_idx % (1 << sparse_degree); + let high_idx = omega_idx >> sparse_degree; + let mut omega = twiddle_lut[low_idx]; + if high_idx > 0 { + omega *= twiddle_lut[(1 << sparse_degree) + high_idx]; + } + let mut w_m = Scalar::ONE; + for i in 0..split_m { + t1[i] *= &w_m; + tmp[i] = t1[i]; + w_m *= omega; + } +} + +/// Precalculate twiddles factors +fn generate_twiddle_lookup_table( + omega: F, + log_n: u32, + sparse_degree: u32, + with_last_level: bool, +) -> Vec { + let without_last_level = !with_last_level; + let is_lut_len_large = sparse_degree > log_n; + + // dense + if is_lut_len_large { + let mut twiddle_lut = vec![F::ZERO; (1 << log_n) as usize]; + parallelize(&mut twiddle_lut, |twiddle_lut, start| { + let mut w_n = omega.pow_vartime([start as u64, 0, 0, 0]); + for 
twiddle_lut in twiddle_lut.iter_mut() { + *twiddle_lut = w_n; + w_n *= omega; + } + }); + return twiddle_lut; + } + + // sparse + let low_degree_lut_len = 1 << sparse_degree; + let high_degree_lut_len = 1 << (log_n - sparse_degree - without_last_level as u32); + let mut twiddle_lut = vec![F::ZERO; low_degree_lut_len + high_degree_lut_len]; + parallelize( + &mut twiddle_lut[..low_degree_lut_len], + |twiddle_lut, start| { + let mut w_n = omega.pow_vartime([start as u64, 0, 0, 0]); + for twiddle_lut in twiddle_lut.iter_mut() { + *twiddle_lut = w_n; + w_n *= omega; + } + }, + ); + let high_degree_omega = omega.pow_vartime([(1 << sparse_degree) as u64, 0, 0, 0]); + parallelize( + &mut twiddle_lut[low_degree_lut_len..], + |twiddle_lut, start| { + let mut w_n = high_degree_omega.pow_vartime([start as u64, 0, 0, 0]); + for twiddle_lut in twiddle_lut.iter_mut() { + *twiddle_lut = w_n; + w_n *= high_degree_omega; + } + }, + ); + twiddle_lut +} + +/// The parallel implementation +fn parallel_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + let n = a.len(); + assert_eq!(n, 1 << log_n); + + let log_split = log2_floor(multicore::current_num_threads()) as usize; + let split_m = 1 << log_split; + let sub_n = n >> log_split; + let twiddle_lut = generate_twiddle_lookup_table(omega, log_n, SPARSE_TWIDDLE_DEGREE, true); + + // split fft + let tmp_filler_val = a[0]; + let mut tmp = vec![tmp_filler_val; n]; + multicore::scope(|scope| { + let a = &*a; + let twiddle_lut = &*twiddle_lut; + for (chunk_idx, tmp) in tmp.chunks_mut(sub_n).enumerate() { + scope.spawn(move |_| { + let split_fft_offset = (chunk_idx * sub_n) >> log_split; + for (i, tmp) in tmp.chunks_mut(split_m).enumerate() { + let split_fft_offset = split_fft_offset + i; + split_radix_fft(tmp, a, twiddle_lut, n, split_fft_offset, log_split); + } + }); + } + }); + + // shuffle + parallelize(a, |a, start| { + for (idx, a) in a.iter_mut().enumerate() { + let idx = start + idx; + let i = idx / sub_n; + let j = idx % sub_n; + *a = tmp[j * split_m + i]; + } + }); + + // sub fft + let new_omega = omega.pow_vartime([split_m as u64, 0, 0, 0]); + multicore::scope(|scope| { + for a in a.chunks_mut(sub_n) { + scope.spawn(move |_| { + serial_fft(a, new_omega, log_n - log_split as u32); + }); + } + }); + + // copy & unshuffle + let mask = (1 << log_split) - 1; + parallelize(&mut tmp, |tmp, start| { + for (idx, tmp) in tmp.iter_mut().enumerate() { + let idx = start + idx; + *tmp = a[idx]; + } + }); + parallelize(a, |a, start| { + for (idx, a) in a.iter_mut().enumerate() { + let idx = start + idx; + *a = tmp[sub_n * (idx & mask) + (idx >> log_split)]; + } + }); +} + +/// This simple utility function will parallelize an operation that is to be +/// performed over a mutable slice. +fn parallelize(v: &mut [T], f: F) { + let n = v.len(); + let num_threads = multicore::current_num_threads(); + let mut chunk = n / num_threads; + if chunk < num_threads { + chunk = n; + } + + multicore::scope(|scope| { + for (chunk_num, v) in v.chunks_mut(chunk).enumerate() { + let f = f.clone(); + scope.spawn(move |_| { + let start = chunk_num * chunk; + f(v, start); + }); + } + }); +} + +/// Generic adaptor +pub fn fft>( + data_in: &mut [G], + omega: Scalar, + log_n: u32, + _data: &FFTData, + _inverse: bool, +) { + best_fft_opt(data_in, omega, log_n) +} diff --git a/halo2_proofs/src/fft/recursive.rs b/halo2_proofs/src/fft/recursive.rs new file mode 100644 index 0000000000..7973336c49 --- /dev/null +++ b/halo2_proofs/src/fft/recursive.rs @@ -0,0 +1,461 @@ +//! This contains the recursive FFT. 
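+// An `FFTData` is built once per size and reused for every transform of that
+// size; the omega passed at call time is ignored because both forward and
+// inverse twiddle tables are baked in. Usage sketch, mirroring the tests in
+// fft.rs:
+//
+//     let data = FFTData::new(n, omega, omega_inv);
+//     fft(&mut values, omega, log_n, &data, false);     // forward
+//     fft(&mut values, omega_inv, log_n, &data, true);  // inverse
+//     // then scale by n^{-1} to finish the iFFT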
+ +use crate::{ + arithmetic::{self, parallelize, FftGroup}, + multicore, +}; + +pub use ff::Field; +pub use halo2curves::{CurveAffine, CurveExt}; + +/// FFTStage +#[derive(Clone, Debug)] +pub struct FFTStage { + radix: usize, + length: usize, +} + +/// FFT stages +fn get_stages(size: usize, radixes: Vec) -> Vec { + let mut stages: Vec = vec![]; + + let mut n = size; + + // Use the specified radices + for &radix in &radixes { + n /= radix; + stages.push(FFTStage { radix, length: n }); + } + + // Fill in the rest of the tree if needed + let mut p = 2; + while n > 1 { + while n % p != 0 { + if p == 4 { + p = 2; + } + } + n /= p; + stages.push(FFTStage { + radix: p, + length: n, + }); + } + + /*for i in 0..stages.len() { + log_info(format!("Stage {}: {}, {}", i, stages[i].radix, stages[i].length)); + }*/ + + stages +} + +/// FFTData +#[derive(Clone, Debug)] +pub struct FFTData { + n: usize, + + stages: Vec, + + f_twiddles: Vec>, + inv_twiddles: Vec>, + //scratch: Vec, +} + +impl Default for FFTData { + fn default() -> Self { + Self { + n: Default::default(), + stages: Default::default(), + f_twiddles: Default::default(), + inv_twiddles: Default::default(), + } + } +} + +impl FFTData { + /// Create FFT data + pub fn new(n: usize, omega: F, omega_inv: F) -> Self { + let stages = get_stages(n, vec![]); + let mut f_twiddles = vec![]; + let mut inv_twiddles = vec![]; + let mut scratch = vec![F::ZERO; n]; + + // Generate stage twiddles + for inv in 0..2 { + let inverse = inv == 0; + let o = if inverse { omega_inv } else { omega }; + let stage_twiddles = if inverse { + &mut inv_twiddles + } else { + &mut f_twiddles + }; + + let twiddles = &mut scratch; + + // Twiddles + parallelize(twiddles, |twiddles, start| { + let w_m = o; + let mut w = o.pow_vartime([start as u64, 0, 0, 0]); + for value in twiddles.iter_mut() { + *value = w; + w *= w_m; + } + }); + + // Re-order twiddles for cache friendliness + let num_stages = stages.len(); + stage_twiddles.resize(num_stages, vec![]); + for l in 0..num_stages { + let radix = stages[l].radix; + let stage_length = stages[l].length; + + let num_twiddles = stage_length * (radix - 1); + stage_twiddles[l].resize(num_twiddles + 1, F::ZERO); + + // Set j + stage_twiddles[l][num_twiddles] = twiddles[(twiddles.len() * 3) / 4]; + + let stride = n / (stage_length * radix); + let mut tws = vec![0usize; radix - 1]; + for i in 0..stage_length { + for j in 0..radix - 1 { + stage_twiddles[l][i * (radix - 1) + j] = twiddles[tws[j]]; + tws[j] += (j + 1) * stride; + } + } + } + } + + Self { + n, + stages, + f_twiddles, + inv_twiddles, + //scratch, + } + } + + /// Return private field `n` + pub fn get_n(&self) -> usize { + self.n + } +} + +/// Radix 2 butterfly +fn butterfly_2>( + out: &mut [G], + twiddles: &[Scalar], + stage_length: usize, +) { + let mut out_offset = 0; + let mut out_offset2 = stage_length; + + let t = out[out_offset2]; + out[out_offset2] = out[out_offset] - &t; + out[out_offset] += &t; + out_offset2 += 1; + out_offset += 1; + + for twiddle in twiddles[1..stage_length].iter() { + let t = out[out_offset2] * twiddle; + out[out_offset2] = out[out_offset] - &t; + out[out_offset] += &t; + out_offset2 += 1; + out_offset += 1; + } +} + +/// Radix 2 butterfly +fn butterfly_2_parallel>( + out: &mut [G], + twiddles: &[Scalar], + _stage_length: usize, + num_threads: usize, +) { + let n = out.len(); + let mut chunk = n / num_threads; + if chunk < num_threads { + chunk = n; + } + + multicore::scope(|scope| { + let (part_a, part_b) = out.split_at_mut(n / 2); + for (i, (part0, 
part1)) in part_a + .chunks_mut(chunk) + .zip(part_b.chunks_mut(chunk)) + .enumerate() + { + scope.spawn(move |_| { + let offset = i * chunk; + for k in 0..part0.len() { + let t = part1[k] * &twiddles[offset + k]; + part1[k] = part0[k] - &t; + part0[k] += &t; + } + }); + } + }); +} + +/// Radix 4 butterfly +fn butterfly_4>( + out: &mut [G], + twiddles: &[Scalar], + stage_length: usize, +) { + let j = twiddles[twiddles.len() - 1]; + let mut tw = 0; + + /* Case twiddle == one */ + { + let i0 = 0; + let i1 = stage_length; + let i2 = stage_length * 2; + let i3 = stage_length * 3; + + let z0 = out[i0]; + let z1 = out[i1]; + let z2 = out[i2]; + let z3 = out[i3]; + + let t1 = z0 + &z2; + let t2 = z1 + &z3; + let t3 = z0 - &z2; + let t4j = (z1 - &z3) * &j; + + out[i0] = t1 + &t2; + out[i1] = t3 - &t4j; + out[i2] = t1 - &t2; + out[i3] = t3 + &t4j; + + tw += 3; + } + + for k in 1..stage_length { + let i0 = k; + let i1 = k + stage_length; + let i2 = k + stage_length * 2; + let i3 = k + stage_length * 3; + + let z0 = out[i0]; + let z1 = out[i1] * &twiddles[tw]; + let z2 = out[i2] * &twiddles[tw + 1]; + let z3 = out[i3] * &twiddles[tw + 2]; + + let t1 = z0 + &z2; + let t2 = z1 + &z3; + let t3 = z0 - &z2; + let t4j = (z1 - &z3) * &j; + + out[i0] = t1 + &t2; + out[i1] = t3 - &t4j; + out[i2] = t1 - &t2; + out[i3] = t3 + &t4j; + + tw += 3; + } +} + +/// Radix 4 butterfly +fn butterfly_4_parallel>( + out: &mut [G], + twiddles: &[Scalar], + _stage_length: usize, + num_threads: usize, +) { + let j = twiddles[twiddles.len() - 1]; + + let n = out.len(); + let mut chunk = n / num_threads; + if chunk < num_threads { + chunk = n; + } + multicore::scope(|scope| { + //let mut parts: Vec<&mut [F]> = out.chunks_mut(4).collect(); + //out.chunks_mut(4).map(|c| c.chunks_mut(chunk)).fold(predicate) + let (part_a, part_b) = out.split_at_mut(n / 2); + let (part_aa, part_ab) = part_a.split_at_mut(n / 4); + let (part_ba, part_bb) = part_b.split_at_mut(n / 4); + for (i, (((part0, part1), part2), part3)) in part_aa + .chunks_mut(chunk) + .zip(part_ab.chunks_mut(chunk)) + .zip(part_ba.chunks_mut(chunk)) + .zip(part_bb.chunks_mut(chunk)) + .enumerate() + { + scope.spawn(move |_| { + let offset = i * chunk; + let mut tw = offset * 3; + for k in 0..part1.len() { + let z0 = part0[k]; + let z1 = part1[k] * &twiddles[tw]; + let z2 = part2[k] * &twiddles[tw + 1]; + let z3 = part3[k] * &twiddles[tw + 2]; + + let t1 = z0 + &z2; + let t2 = z1 + &z3; + let t3 = z0 - &z2; + let t4j = (z1 - &z3) * &j; + + part0[k] = t1 + &t2; + part1[k] = t3 - &t4j; + part2[k] = t1 - &t2; + part3[k] = t3 + &t4j; + + tw += 3; + } + }); + } + }); +} + +/// Inner recursion +#[allow(clippy::too_many_arguments)] +fn recursive_fft_inner>( + data_in: &[G], + data_out: &mut [G], + twiddles: &Vec>, + stages: &Vec, + in_offset: usize, + stride: usize, + level: usize, + num_threads: usize, +) { + let radix = stages[level].radix; + let stage_length = stages[level].length; + + if num_threads > 1 { + if stage_length == 1 { + for i in 0..radix { + data_out[i] = data_in[in_offset + i * stride]; + } + } else { + let num_threads_recursive = if num_threads >= radix { + radix + } else { + num_threads + }; + parallelize_count(data_out, num_threads_recursive, |data_out, i| { + let num_threads_in_recursion = if num_threads < radix { + 1 + } else { + (num_threads + i) / radix + }; + recursive_fft_inner( + data_in, + data_out, + twiddles, + stages, + in_offset + i * stride, + stride * radix, + level + 1, + num_threads_in_recursion, + ) + }); + } + match radix { + 2 => 
butterfly_2_parallel(data_out, &twiddles[level], stage_length, num_threads), + 4 => butterfly_4_parallel(data_out, &twiddles[level], stage_length, num_threads), + _ => unimplemented!("radix unsupported"), + } + } else { + if stage_length == 1 { + for i in 0..radix { + data_out[i] = data_in[in_offset + i * stride]; + } + } else { + for i in 0..radix { + recursive_fft_inner( + data_in, + &mut data_out[i * stage_length..(i + 1) * stage_length], + twiddles, + stages, + in_offset + i * stride, + stride * radix, + level + 1, + num_threads, + ); + } + } + match radix { + 2 => butterfly_2(data_out, &twiddles[level], stage_length), + 4 => butterfly_4(data_out, &twiddles[level], stage_length), + _ => unimplemented!("radix unsupported"), + } + } +} + +/// Todo: Brechts impl starts here +fn recursive_fft>( + data: &FFTData, + data_in: &mut Vec, + inverse: bool, +) { + let num_threads = multicore::current_num_threads(); + //let start = start_measure(format!("recursive fft {} ({})", data_in.len(), num_threads), false); + + // TODO: reuse scratch buffer between FFTs + //let start_mem = start_measure(format!("alloc"), false); + let filler = data_in[0]; + let mut scratch = vec![filler; data_in.len()]; + //stop_measure(start_mem); + + recursive_fft_inner( + data_in, + &mut /*data.*/scratch, + if inverse { + &data.inv_twiddles + } else { + &data.f_twiddles + }, + &data.stages, + 0, + 1, + 0, + num_threads, + ); + //let duration = stop_measure(start); + + //let start = start_measure(format!("copy"), false); + // Will simply swap the vector's buffer, no data is actually copied + std::mem::swap(data_in, &mut /*data.*/scratch); + //stop_measure(start); +} + +/// This simple utility function will parallelize an operation that is to be +/// performed over a mutable slice. +fn parallelize_count( + v: &mut [T], + num_threads: usize, + f: F, +) { + let n = v.len(); + let mut chunk = n / num_threads; + if chunk < num_threads { + chunk = n; + } + + multicore::scope(|scope| { + for (chunk_num, v) in v.chunks_mut(chunk).enumerate() { + let f = f.clone(); + scope.spawn(move |_| { + f(v, chunk_num); + }); + } + }); +} + +/// Generic adaptor +pub fn fft>( + data_in: &mut [G], + _omega: Scalar, + _log_n: u32, + data: &FFTData, + inverse: bool, +) { + let orig_len = data_in.len(); + let mut data_in_vec = data_in.to_vec(); + recursive_fft(data, &mut data_in_vec, inverse); + data_in.copy_from_slice(&data_in_vec); + assert_eq!(orig_len, data_in.len()); +} diff --git a/halo2_proofs/src/lib.rs b/halo2_proofs/src/lib.rs index bd1d38d636..6de9087883 100644 --- a/halo2_proofs/src/lib.rs +++ b/halo2_proofs/src/lib.rs @@ -12,6 +12,7 @@ pub mod arithmetic; pub mod circuit; pub use halo2curves; +pub mod fft; mod multicore; pub mod plonk; pub mod poly; diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index ee77b6569d..69e469527a 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -13,10 +13,7 @@ use crate::helpers::{ polynomial_slice_byte_length, read_polynomial_vec, write_polynomial_slice, SerdeCurveAffine, SerdePrimeField, }; -use crate::poly::{ - Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff, PinnedEvaluationDomain, - Polynomial, -}; +use crate::poly::{Coeff, EvaluationDomain, LagrangeCoeff, PinnedEvaluationDomain, Polynomial}; use crate::transcript::{ChallengeScalar, EncodedChallenge, Transcript}; use crate::SerdeFormat; @@ -27,7 +24,7 @@ mod evaluation; mod keygen; mod lookup; pub mod permutation; -mod shuffle; +// mod shuffle; mod vanishing; mod prover; @@ -317,12 +314,11 @@ 
pub struct PinnedVerificationKey<'a, C: CurveAffine> { #[derive(Clone, Debug)] pub struct ProvingKey { vk: VerifyingKey, - l0: Polynomial, - l_last: Polynomial, - l_active_row: Polynomial, + l0: Polynomial, + l_last: Polynomial, + l_active_row: Polynomial, fixed_values: Vec>, fixed_polys: Vec>, - fixed_cosets: Vec>, permutation: permutation::ProvingKey, ev: Evaluator, } @@ -344,7 +340,6 @@ where + scalar_len * (self.l0.len() + self.l_last.len() + self.l_active_row.len()) + polynomial_slice_byte_length(&self.fixed_values) + polynomial_slice_byte_length(&self.fixed_polys) - + polynomial_slice_byte_length(&self.fixed_cosets) + self.permutation.bytes_length() } } @@ -364,13 +359,12 @@ where /// WITHOUT performing the expensive Montgomery reduction. /// Does so by first writing the verifying key and then serializing the rest of the data (in the form of field polynomials) pub fn write(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> { - self.vk.write(writer, format).unwrap(); + self.vk.write(writer, format)?; self.l0.write(writer, format); self.l_last.write(writer, format); self.l_active_row.write(writer, format); write_polynomial_slice(&self.fixed_values, writer, format); write_polynomial_slice(&self.fixed_polys, writer, format); - write_polynomial_slice(&self.fixed_cosets, writer, format); self.permutation.write(writer, format); Ok(()) } @@ -396,14 +390,12 @@ where format, #[cfg(feature = "circuit-params")] params, - ) - .unwrap(); + )?; let l0 = Polynomial::read(reader, format); let l_last = Polynomial::read(reader, format); let l_active_row = Polynomial::read(reader, format); let fixed_values = read_polynomial_vec(reader, format); let fixed_polys = read_polynomial_vec(reader, format); - let fixed_cosets = read_polynomial_vec(reader, format); let permutation = permutation::ProvingKey::read(reader, format); let ev = Evaluator::new(vk.cs()); Ok(Self { @@ -413,7 +405,6 @@ where l_active_row, fixed_values, fixed_polys, - fixed_cosets, permutation, ev, }) diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index 93897e8697..19443aa6c1 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -1,4 +1,4 @@ -use super::{lookup, permutation, shuffle, Assigned, Error}; +use super::{lookup, permutation, Assigned, Error}; use crate::circuit::layouter::SyncDeps; use crate::dev::metadata; use crate::{ @@ -8,6 +8,7 @@ use crate::{ use core::cmp::max; use core::ops::{Add, Mul}; use ff::Field; +use itertools::Itertools; use sealed::SealedPhase; use std::collections::HashMap; use std::env::var; @@ -1247,6 +1248,72 @@ impl Expression { &|a, _| a, ) } + + /// Extracts all used instance columns in this expression + pub fn extract_instances(&self) -> Vec { + self.evaluate( + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|query| vec![query.column_index], + &|_| vec![], + &|a| a, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|a, _| a, + ) + } + + /// Extracts all used advice columns in this expression + pub fn extract_advices(&self) -> Vec { + self.evaluate( + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|query| vec![query.column_index], + &|_| vec![], + &|_| vec![], + &|a| a, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|a, _| a, + ) + } + + /// Extracts all used fixed columns in this expression + pub fn extract_fixed(&self) -> 
Vec { + self.evaluate( + &|_| vec![], + &|_| vec![], + &|query| vec![query.column_index], + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|a| a, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|a, _| a, + ) + } } impl std::fmt::Debug for Expression { @@ -1563,10 +1630,6 @@ pub struct ConstraintSystem { // input expressions and a sequence of table expressions involved in the lookup. pub(crate) lookups: Vec>, - // Vector of shuffle arguments, where each corresponds to a sequence of - // input expressions and a sequence of shuffle expressions involved in the shuffle. - pub(crate) shuffles: Vec>, - // List of indexes of Fixed columns which are associated to a circuit-general Column tied to their annotation. pub(crate) general_column_annotations: HashMap, @@ -1593,7 +1656,6 @@ pub struct PinnedConstraintSystem<'a, F: Field> { fixed_queries: &'a Vec<(Column, Rotation)>, permutation: &'a permutation::Argument, lookups: &'a Vec>, - shuffles: &'a Vec>, constants: &'a Vec>, minimum_degree: &'a Option, } @@ -1654,7 +1716,6 @@ impl Default for ConstraintSystem { instance_queries: Vec::new(), permutation: permutation::Argument::new(), lookups: Vec::new(), - shuffles: Vec::new(), general_column_annotations: HashMap::new(), constants: vec![], minimum_degree: None, @@ -1681,7 +1742,6 @@ impl ConstraintSystem { instance_queries: &self.instance_queries, permutation: &self.permutation, lookups: &self.lookups, - shuffles: &self.shuffles, constants: &self.constants, minimum_degree: &self.minimum_degree, } @@ -1762,29 +1822,6 @@ impl ConstraintSystem { index } - /// Add a shuffle argument for some input expressions and table expressions. - pub fn shuffle>( - &mut self, - name: S, - shuffle_map: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec<(Expression, Expression)>, - ) -> usize { - let mut cells = VirtualCells::new(self); - let shuffle_map = shuffle_map(&mut cells) - .into_iter() - .map(|(mut input, mut table)| { - input.query_cells(&mut cells); - table.query_cells(&mut cells); - (input, table) - }) - .collect(); - let index = self.shuffles.len(); - - self.shuffles - .push(shuffle::Argument::new(name.as_ref(), shuffle_map)); - - index - } - fn query_fixed_index(&mut self, column: Column, at: Rotation) -> usize { // Return existing query, if it exists for (index, fixed_query) in self.fixed_queries.iter().enumerate() { @@ -2083,15 +2120,6 @@ impl ConstraintSystem { }) { replace_selectors(expr, selector_replacements, true); } - - for expr in self.shuffles.iter_mut().flat_map(|shuffle| { - shuffle - .input_expressions - .iter_mut() - .chain(shuffle.shuffle_expressions.iter_mut()) - }) { - replace_selectors(expr, selector_replacements, true); - } } /// Allocate a new (simple) selector. Simple selectors cannot be added to @@ -2252,17 +2280,6 @@ impl ConstraintSystem { .unwrap_or(1), ); - // The lookup argument also serves alongside the gates and must be accounted - // for. - degree = std::cmp::max( - degree, - self.shuffles - .iter() - .map(|l| l.required_degree()) - .max() - .unwrap_or(1), - ); - // Account for each gate to ensure our quotient polynomial is the // correct degree and that our extended domain is the right size. 
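// (A gate polynomial of degree d makes the quotient
// h(X) = gate(X) / (X^n - 1) have degree roughly (d - 1)·n, so the maximum
// taken here determines how far the evaluation domain must be extended.)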
degree = std::cmp::max( @@ -2399,11 +2416,6 @@ impl ConstraintSystem { &self.lookups } - /// Returns shuffle arguments - pub fn shuffles(&self) -> &Vec> { - &self.shuffles - } - /// Returns constants pub fn constants(&self) -> &Vec> { &self.constants diff --git a/halo2_proofs/src/plonk/evaluation.rs b/halo2_proofs/src/plonk/evaluation.rs index c5a4977be4..64de1bffa0 100644 --- a/halo2_proofs/src/plonk/evaluation.rs +++ b/halo2_proofs/src/plonk/evaluation.rs @@ -3,11 +3,12 @@ use crate::plonk::{lookup, permutation, Any, ProvingKey}; use crate::poly::Basis; use crate::{ arithmetic::{parallelize, CurveAffine}, - poly::{Coeff, ExtendedLagrangeCoeff, Polynomial, Rotation}, + poly::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation}, }; use group::ff::{Field, PrimeField, WithSmallOrderMulGroup}; +use itertools::Itertools; -use super::{shuffle, ConstraintSystem, Expression}; +use super::{ConstraintSystem, Expression}; /// Return the index in the polynomial of size `isize` after rotation `rot`. fn get_rotation_idx(idx: usize, rot: i32, rot_scale: i32, isize: i32) -> usize { @@ -36,9 +37,7 @@ pub enum ValueSource { /// theta Theta(), /// y - Y(), - /// Previous value - PreviousValue(), + Y(usize), } impl Default for ValueSource { @@ -55,34 +54,35 @@ impl ValueSource { rotations: &[usize], constants: &[F], intermediates: &[F], - fixed_values: &[Polynomial], - advice_values: &[Polynomial], - instance_values: &[Polynomial], + fixed_values: &[Option>], + advice_values: &[Option>], + instance_values: &[Option>], challenges: &[F], + y_powers: &[F], beta: &F, gamma: &F, theta: &F, - y: &F, - previous_value: &F, ) -> F { match self { ValueSource::Constant(idx) => constants[*idx], ValueSource::Intermediate(idx) => intermediates[*idx], ValueSource::Fixed(column_index, rotation) => { - fixed_values[*column_index][rotations[*rotation]] + assert!(fixed_values[*column_index].is_some()); + fixed_values[*column_index].as_ref().unwrap()[rotations[*rotation]] } ValueSource::Advice(column_index, rotation) => { - advice_values[*column_index][rotations[*rotation]] + assert!(advice_values[*column_index].is_some()); + advice_values[*column_index].as_ref().unwrap()[rotations[*rotation]] } ValueSource::Instance(column_index, rotation) => { - instance_values[*column_index][rotations[*rotation]] + assert!(instance_values[*column_index].is_some()); + instance_values[*column_index].as_ref().unwrap()[rotations[*rotation]] } ValueSource::Challenge(index) => challenges[*index], ValueSource::Beta() => *beta, ValueSource::Gamma() => *gamma, ValueSource::Theta() => *theta, - ValueSource::Y() => *y, - ValueSource::PreviousValue() => *previous_value, + ValueSource::Y(idx) => y_powers[*idx], } } } @@ -116,15 +116,14 @@ impl Calculation { rotations: &[usize], constants: &[F], intermediates: &[F], - fixed_values: &[Polynomial], - advice_values: &[Polynomial], - instance_values: &[Polynomial], + fixed_values: &[Option>], + advice_values: &[Option>], + instance_values: &[Option>], challenges: &[F], + y_powers: &[F], beta: &F, gamma: &F, theta: &F, - y: &F, - previous_value: &F, ) -> F { let get_value = |value: &ValueSource| { value.get( @@ -135,11 +134,10 @@ impl Calculation { advice_values, instance_values, challenges, + y_powers, beta, gamma, theta, - y, - previous_value, ) }; match self { @@ -162,15 +160,41 @@ impl Calculation { } } +#[derive(Clone, Default, Debug)] +struct ConstraintCluster { + /// Used fixed columns in each cluster + used_fixed_columns: Vec, + /// Used instance columns in each cluster + 
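+// (Cluster background: `compute_cluster_idx` below maps each constraint's
+// degree to a cluster index, capped by the `max_cluster_idx` derived from
+// the quotient degree, and each cluster records only the columns its own
+// constraints touch; evaluating clusters independently is what enables the
+// memory reduction this patch targets.)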
used_instance_columns: Vec<usize>,
+    /// Used advice columns in each cluster
+    used_advice_columns: Vec<usize>,
+    /// Custom gates evaluation
+    evaluator: GraphEvaluator<C>,
+    /// Index of the first constraint evaluated in this cluster
+    first_constraint_idx: usize,
+    /// Index of the last constraint evaluated in this cluster
+    last_constraint_idx: usize,
+    /// The last value source
+    last_value_source: Option<ValueSource>,
+}
+
 /// Evaluator
 #[derive(Clone, Default, Debug)]
 pub struct Evaluator<C: CurveAffine> {
-    /// Custom gates evalution
-    pub custom_gates: GraphEvaluator<C>,
-    /// Lookups evalution
-    pub lookups: Vec<GraphEvaluator<C>>,
-    /// Shuffle evalution
-    pub shuffles: Vec<GraphEvaluator<C>>,
+    /// List of constraint clusters
+    custom_gate_clusters: Vec<ConstraintCluster<C>>,
+    /// Number of custom gate constraints
+    num_custom_gate_constraints: usize,
+    /// Lookup evaluators, each with its degree and the fixed, instance and
+    /// advice columns it uses
+    #[allow(clippy::type_complexity)]
+    lookups: Vec<(
+        GraphEvaluator<C>,
+        usize,
+        (Vec<usize>, Vec<usize>, Vec<usize>),
+    )>,
+
+    /// Powers of y
+    num_y_powers: usize,
 }

 /// GraphEvaluator
@@ -204,46 +228,122 @@ pub struct CalculationInfo {
     pub target: usize,
 }

+fn merge_unique(a: Vec<usize>, b: Vec<usize>) -> Vec<usize> {
+    let mut result = a;
+    result.extend(b);
+    result.into_iter().unique().collect()
+}
+
 impl<C: CurveAffine> Evaluator<C> {
     /// Creates a new evaluation structure
     pub fn new(cs: &ConstraintSystem<C::ScalarExt>) -> Self {
         let mut ev = Evaluator::default();
+        let mut constraint_idx = 0;
+
+        // Compute the max cluster index
+        let quotient_poly_degree = (cs.degree() - 1) as u64;
+        let mut max_cluster_idx = 0;
+        while (1 << max_cluster_idx) < quotient_poly_degree {
+            max_cluster_idx += 1;
+        }
+
+        ev.custom_gate_clusters
+            .resize(max_cluster_idx + 1, ConstraintCluster::default());

         // Custom gates
-        let mut parts = Vec::new();
         for gate in cs.gates.iter() {
-            parts.extend(
-                gate.polynomials()
-                    .iter()
-                    .map(|poly| ev.custom_gates.add_expression(poly)),
-            );
+            for poly in gate.polynomials() {
+                constraint_idx += 1;
+                let cluster_idx = Self::compute_cluster_idx(poly.degree(), max_cluster_idx);
+                let custom_gate_cluster = &mut ev.custom_gate_clusters[cluster_idx];
+                custom_gate_cluster.used_fixed_columns = merge_unique(
+                    custom_gate_cluster.used_fixed_columns.clone(),
+                    poly.extract_fixed(),
+                );
+                custom_gate_cluster.used_instance_columns = merge_unique(
+                    custom_gate_cluster.used_instance_columns.clone(),
+                    poly.extract_instances(),
+                );
+                custom_gate_cluster.used_advice_columns = merge_unique(
+                    custom_gate_cluster.used_advice_columns.clone(),
+                    poly.extract_advices(),
+                );
+                let curr = custom_gate_cluster.evaluator.add_expression(poly);
+                if let Some(last) = custom_gate_cluster.last_value_source {
+                    custom_gate_cluster.last_value_source = Some(
+                        custom_gate_cluster
+                            .evaluator
+                            .add_calculation(Calculation::Horner(
+                                last,
+                                vec![curr],
+                                ValueSource::Y(
+                                    constraint_idx - custom_gate_cluster.last_constraint_idx,
+                                ),
+                            )),
+                    );
+                } else {
+                    assert_eq!(custom_gate_cluster.last_constraint_idx, 0);
+                    custom_gate_cluster.last_value_source = Some(curr);
+                    custom_gate_cluster.first_constraint_idx = constraint_idx;
+                }
+                custom_gate_cluster.last_constraint_idx = constraint_idx;
+            }
         }
-        ev.custom_gates.add_calculation(Calculation::Horner(
-            ValueSource::PreviousValue(),
-            parts,
-            ValueSource::Y(),
-        ));
+
+        ev.num_custom_gate_constraints = constraint_idx;

         // Lookups
         for lookup in cs.lookups.iter() {
+            constraint_idx += 5;
             let mut graph = GraphEvaluator::default();

             let mut evaluate_lc = |expressions: &Vec<Expression<C::ScalarExt>>| {
+                let mut max_degree = 0;
+                let mut used_fixed_columns = vec![];
+                let mut
used_instance_columns = vec![]; + let mut used_advice_columns = vec![]; let parts = expressions .iter() - .map(|expr| graph.add_expression(expr)) + .map(|expr| { + max_degree = max_degree.max(expr.degree()); + used_fixed_columns = + merge_unique(used_fixed_columns.clone(), expr.extract_fixed()); + used_instance_columns = + merge_unique(used_instance_columns.clone(), expr.extract_instances()); + used_advice_columns = + merge_unique(used_advice_columns.clone(), expr.extract_advices()); + graph.add_expression(expr) + }) .collect(); - graph.add_calculation(Calculation::Horner( - ValueSource::Constant(0), - parts, - ValueSource::Theta(), - )) + ( + graph.add_calculation(Calculation::Horner( + ValueSource::Constant(0), + parts, + ValueSource::Theta(), + )), + max_degree, + used_fixed_columns, + used_instance_columns, + used_advice_columns, + ) }; // Input coset - let compressed_input_coset = evaluate_lc(&lookup.input_expressions); + let ( + compressed_input_coset, + max_input_degree, + input_used_fixed, + input_used_instances, + input_used_advices, + ) = evaluate_lc(&lookup.input_expressions); // table coset - let compressed_table_coset = evaluate_lc(&lookup.table_expressions); + let ( + compressed_table_coset, + max_table_degree, + table_used_fixed, + table_used_instances, + table_used_advices, + ) = evaluate_lc(&lookup.table_expressions); // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) let right_gamma = graph.add_calculation(Calculation::Add( compressed_table_coset, @@ -254,43 +354,21 @@ impl Evaluator { ValueSource::Beta(), )); graph.add_calculation(Calculation::Mul(lc, right_gamma)); - - ev.lookups.push(graph); - } - - // Shuffles - for shuffle in cs.shuffles.iter() { - let evaluate_lc = |expressions: &Vec>, graph: &mut GraphEvaluator| { - let parts = expressions - .iter() - .map(|expr| graph.add_expression(expr)) - .collect(); - graph.add_calculation(Calculation::Horner( - ValueSource::Constant(0), - parts, - ValueSource::Theta(), - )) - }; - - let mut graph_input = GraphEvaluator::default(); - let compressed_input_coset = evaluate_lc(&shuffle.input_expressions, &mut graph_input); - let _ = graph_input.add_calculation(Calculation::Add( - compressed_input_coset, - ValueSource::Gamma(), + ev.lookups.push(( + graph, + max_input_degree + max_table_degree, + ( + merge_unique(input_used_fixed, table_used_fixed), + merge_unique(input_used_instances, table_used_instances), + merge_unique(input_used_advices, table_used_advices), + ), )); - - let mut graph_shuffle = GraphEvaluator::default(); - let compressed_shuffle_coset = - evaluate_lc(&shuffle.shuffle_expressions, &mut graph_shuffle); - let _ = graph_shuffle.add_calculation(Calculation::Add( - compressed_shuffle_coset, - ValueSource::Gamma(), - )); - - ev.shuffles.push(graph_input); - ev.shuffles.push(graph_shuffle); } + // Count the constraints in permutation + let num_sets = (cs.permutation.get_columns().len() + (cs.degree() - 3)) / (cs.degree() - 2); + constraint_idx += 1 + num_sets * 2; + ev.num_y_powers = constraint_idx + 10; ev } @@ -307,297 +385,579 @@ impl Evaluator { gamma: C::ScalarExt, theta: C::ScalarExt, lookups: &[Vec>], - shuffles: &[Vec>], + // shuffles: &[Vec>], permutations: &[permutation::prover::Committed], ) -> Polynomial { let domain = &pk.vk.domain; - let size = domain.extended_len(); - let rot_scale = 1 << (domain.extended_k() - domain.k()); - let fixed = &pk.fixed_cosets[..]; + let size = 1 << domain.k() as usize; + let rot_scale = 1; let extended_omega = domain.get_extended_omega(); + let omega = domain.get_omega(); let 
isize = size as i32;
         let one = C::ScalarExt::ONE;
-        let l0 = &pk.l0;
-        let l_last = &pk.l_last;
-        let l_active_row = &pk.l_active_row;
         let p = &pk.vk.cs.permutation;
+        let num_parts = domain.extended_len() >> domain.k();
+        let num_clusters = (domain.extended_k() - domain.k() + 1) as usize;

-        // Calculate the advice and instance cosets
-        let advice: Vec<Vec<Polynomial<C::ScalarExt, ExtendedLagrangeCoeff>>> = advice_polys
-            .iter()
-            .map(|advice_polys| {
-                advice_polys
-                    .iter()
-                    .map(|poly| domain.coeff_to_extended(poly))
-                    .collect()
-            })
-            .collect();
-        let instance: Vec<Vec<Polynomial<C::ScalarExt, ExtendedLagrangeCoeff>>> = instance_polys
-            .iter()
-            .map(|instance_polys| {
-                instance_polys
-                    .iter()
-                    .map(|poly| domain.coeff_to_extended(poly))
-                    .collect()
-            })
-            .collect();
+        assert!(self.custom_gate_clusters.len() <= num_clusters);
+
+        // Initialize the powers of y and the constraint counter
+        let mut y_powers = vec![C::ScalarExt::ONE; self.num_y_powers * instance_polys.len()];
+        for i in 1..self.num_y_powers {
+            y_powers[i] = y_powers[i - 1] * y;
+        }

-        let mut values = domain.empty_extended();
+        let need_to_compute = |part_idx, cluster_idx| part_idx % (num_parts >> cluster_idx) == 0;
+        let compute_part_idx_in_cluster =
+            |part_idx, cluster_idx| part_idx >> (num_clusters - cluster_idx - 1);

-        // Core expression evaluations
-        let num_threads = multicore::current_num_threads();
-        for ((((advice, instance), lookups), shuffles), permutation) in advice
-            .iter()
-            .zip(instance.iter())
-            .zip(lookups.iter())
-            .zip(shuffles.iter())
-            .zip(permutations.iter())
+        let mut value_part_clusters = Vec::new();
+        value_part_clusters.resize(num_clusters, Vec::new());
+        for (cluster_idx, cluster) in value_part_clusters
+            .iter_mut()
+            .enumerate()
+            .take(num_clusters)
         {
-            // Custom gates
-            multicore::scope(|scope| {
-                let chunk_size = (size + num_threads - 1) / num_threads;
-                for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() {
-                    let start = thread_idx * chunk_size;
-                    scope.spawn(move |_| {
-                        let mut eval_data = self.custom_gates.instance();
-                        for (i, value) in values.iter_mut().enumerate() {
-                            let idx = start + i;
-                            *value = self.custom_gates.evaluate(
-                                &mut eval_data,
-                                fixed,
-                                advice,
-                                instance,
-                                challenges,
-                                &beta,
-                                &gamma,
-                                &theta,
-                                &y,
-                                value,
-                                idx,
-                                rot_scale,
-                                isize,
-                            );
+            cluster.resize(1 << cluster_idx, domain.empty_lagrange());
+        }
+
+        // Calculate the quotient polynomial for each part
+        let mut current_extended_omega = one;
+        for part_idx in 0..num_parts {
+            let mut fixed: Vec<Option<Polynomial<C::ScalarExt, LagrangeCoeff>>> =
+                vec![None; pk.fixed_polys.len()];
+            let l0 = domain.coeff_to_extended_part(pk.l0.clone(), current_extended_omega);
+            let l_last = domain.coeff_to_extended_part(pk.l_last.clone(), current_extended_omega);
+            let l_active_row =
+                domain.coeff_to_extended_part(pk.l_active_row.clone(), current_extended_omega);
+
+            let mut constraint_idx = 0;
+            let mut cluster_last_constraint_idx = vec![0; num_clusters];
+
+            // Core expression evaluations
+            let num_threads = multicore::current_num_threads();
+            for (((advice_polys, instance_polys), lookups), permutation) in advice_polys
+                .iter()
+                .zip(instance_polys.iter())
+                .zip(lookups.iter())
+                .zip(permutations.iter())
+            {
+                // Calculate the advice and instance cosets
+                let mut advice: Vec<Option<Polynomial<C::ScalarExt, LagrangeCoeff>>> =
+                    vec![None; advice_polys.len()];
+                let mut instance: Vec<Option<Polynomial<C::ScalarExt, LagrangeCoeff>>> =
+                    vec![None; instance_polys.len()];
+
+                // Custom gates
+                for (cluster_idx, custom_gates) in self.custom_gate_clusters.iter().enumerate() {
+                    if !need_to_compute(part_idx, cluster_idx)
+                        || custom_gates.last_value_source.is_none()
+                    {
+                        continue;
+                    }
+                    let values = &mut
value_part_clusters[cluster_idx] + [compute_part_idx_in_cluster(part_idx, cluster_idx)]; + for fixed_idx in custom_gates.used_fixed_columns.iter() { + if fixed[*fixed_idx].is_none() { + fixed[*fixed_idx] = Some(domain.coeff_to_extended_part( + pk.fixed_polys[*fixed_idx].clone(), + current_extended_omega, + )); + } + } + for instance_idx in custom_gates.used_instance_columns.iter() { + if instance[*instance_idx].is_none() { + instance[*instance_idx] = Some(domain.coeff_to_extended_part( + instance_polys[*instance_idx].clone(), + current_extended_omega, + )); + } + } + for advice_idx in custom_gates.used_advice_columns.iter() { + if advice[*advice_idx].is_none() { + advice[*advice_idx] = Some(domain.coeff_to_extended_part( + advice_polys[*advice_idx].clone(), + current_extended_omega, + )); + } + } + let fixed_slice = &fixed[..]; + let advice_slice = &advice[..]; + let instance_slice = &instance[..]; + let y_power_slice = &y_powers[..]; + let y_power = y_powers[constraint_idx + custom_gates.first_constraint_idx + - cluster_last_constraint_idx[cluster_idx]]; + multicore::scope(|scope| { + let chunk_size = (size + num_threads - 1) / num_threads; + for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() { + let start = thread_idx * chunk_size; + scope.spawn(move |_| { + let mut eval_data = custom_gates.evaluator.instance(); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + *value = *value * y_power + + custom_gates.evaluator.evaluate( + &mut eval_data, + fixed_slice, + advice_slice, + instance_slice, + challenges, + y_power_slice, + &beta, + &gamma, + &theta, + idx, + rot_scale, + isize, + ); + } + }); } }); + + // Update the constraint index + cluster_last_constraint_idx[cluster_idx] = + constraint_idx + custom_gates.last_constraint_idx; } - }); - - // Permutations - let sets = &permutation.sets; - if !sets.is_empty() { - let blinding_factors = pk.vk.cs.blinding_factors(); - let last_rotation = Rotation(-((blinding_factors + 1) as i32)); - let chunk_len = pk.vk.cs.degree() - 2; - let delta_start = beta * &C::Scalar::ZETA; - - let first_set = sets.first().unwrap(); - let last_set = sets.last().unwrap(); - - // Permutation constraints - parallelize(&mut values, |values, start| { - let mut beta_term = extended_omega.pow_vartime([start as u64, 0, 0, 0]); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - let r_next = get_rotation_idx(idx, 1, rot_scale, isize); - let r_last = get_rotation_idx(idx, last_rotation.0, rot_scale, isize); - - // Enforce only for the first set. - // l_0(X) * (1 - z_0(X)) = 0 - *value = *value * y - + ((one - first_set.permutation_product_coset[idx]) * l0[idx]); - // Enforce only for the last set. - // l_last(X) * (z_l(X)^2 - z_l(X)) = 0 - *value = *value * y - + ((last_set.permutation_product_coset[idx] - * last_set.permutation_product_coset[idx] - - last_set.permutation_product_coset[idx]) - * l_last[idx]); - // Except for the first set, enforce. 
- // l_0(X) * (z_i(X) - z_{i-1}(\omega^(last) X)) = 0 - for (set_idx, set) in sets.iter().enumerate() { - if set_idx != 0 { - *value = *value * y - + ((set.permutation_product_coset[idx] - - permutation.sets[set_idx - 1].permutation_product_coset - [r_last]) - * l0[idx]); + constraint_idx += self.num_custom_gate_constraints; + + // Permutations + let sets = &permutation.sets; + if !sets.is_empty() { + let blinding_factors = pk.vk.cs.blinding_factors(); + let last_rotation = Rotation(-((blinding_factors + 1) as i32)); + let chunk_len = pk.vk.cs.degree() - 2; + let delta_start = beta * &C::Scalar::ZETA; + + let permutation_product_cosets: Vec> = + sets.iter() + .map(|set| { + domain.coeff_to_extended_part( + set.permutation_product_poly.clone(), + current_extended_omega, + ) + }) + .collect(); + + let first_set_permutation_product_coset = + permutation_product_cosets.first().unwrap(); + let last_set_permutation_product_coset = + permutation_product_cosets.last().unwrap(); + + // Permutation constraints + constraint_idx += 1; + if need_to_compute(part_idx, 1) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // Enforce only for the first set. + // l_0(X) * (1 - z_0(X)) = 0, degree = 2 + *value = *value * y_power + + ((one - first_set_permutation_product_coset[idx]) + * l0[idx]); + } + }, + ); + cluster_last_constraint_idx[1] = constraint_idx; + } + + constraint_idx += 1; + if need_to_compute(part_idx, 2) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; + parallelize( + &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // Enforce only for the last set. + // l_last(X) * (z_l(X)^2 - z_l(X)) = 0, degree = 3 + *value = *value * y_power + + ((last_set_permutation_product_coset[idx] + * last_set_permutation_product_coset[idx] + - last_set_permutation_product_coset[idx]) + * l_last[idx]); + } + }, + ); + cluster_last_constraint_idx[2] = constraint_idx; + } + + constraint_idx += sets.len() - 1; + if need_to_compute(part_idx, 1) { + let y_skip = y_powers + [constraint_idx + 1 - sets.len() - cluster_last_constraint_idx[1]]; + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // Except for the first set, enforce. 
+ // l_0(X) * (z_i(X) - z_{i-1}(\omega^(last) X)) = 0, degree = 2 + let r_last = + get_rotation_idx(idx, last_rotation.0, rot_scale, isize); + + *value *= y_skip; + + for (set_idx, permutation_product_coset) in + permutation_product_cosets.iter().enumerate() + { + if set_idx != 0 { + *value = *value * y + + ((permutation_product_coset[idx] + - permutation_product_cosets[set_idx - 1] + [r_last]) + * l0[idx]); + } + } + } + }, + ); + cluster_last_constraint_idx[1] = constraint_idx; + } + + constraint_idx += sets.len(); + let running_prod_cluster = + Self::compute_cluster_idx(2 + chunk_len, num_clusters - 1); + if need_to_compute(part_idx, running_prod_cluster) { + for column in p.columns.iter() { + match column.column_type() { + Any::Advice(_) => { + let advice = &mut advice[column.index()]; + if (*advice).is_none() { + *advice = Some(domain.coeff_to_extended_part( + advice_polys[column.index()].clone(), + current_extended_omega, + )); + } + } + Any::Instance => { + let instance = &mut instance[column.index()]; + if instance.is_none() { + *instance = Some(domain.coeff_to_extended_part( + instance_polys[column.index()].clone(), + current_extended_omega, + )); + } + } + Any::Fixed => { + let fixed = &mut fixed[column.index()]; + if fixed.is_none() { + *fixed = Some(domain.coeff_to_extended_part( + pk.fixed_polys[column.index()].clone(), + current_extended_omega, + )); + } + } } } - // And for all the sets we enforce: - // (1 - (l_last(X) + l_blind(X))) * ( - // z_i(\omega X) \prod_j (p(X) + \beta s_j(X) + \gamma) - // - z_i(X) \prod_j (p(X) + \delta^j \beta X + \gamma) - // ) - let mut current_delta = delta_start * beta_term; - for ((set, columns), cosets) in sets + + let permutation_cosets: Vec> = pk + .permutation + .polys .iter() - .zip(p.columns.chunks(chunk_len)) - .zip(pk.permutation.cosets.chunks(chunk_len)) - { - let mut left = set.permutation_product_coset[r_next]; - for (values, permutation) in columns - .iter() - .map(|&column| match column.column_type() { - Any::Advice(_) => &advice[column.index()], - Any::Fixed => &fixed[column.index()], - Any::Instance => &instance[column.index()], - }) - .zip(cosets.iter()) - { - left *= values[idx] + beta * permutation[idx] + gamma; - } + .map(|p| { + domain.coeff_to_extended_part(p.clone(), current_extended_omega) + }) + .collect(); - let mut right = set.permutation_product_coset[idx]; - for values in columns.iter().map(|&column| match column.column_type() { - Any::Advice(_) => &advice[column.index()], - Any::Fixed => &fixed[column.index()], - Any::Instance => &instance[column.index()], - }) { - right *= values[idx] + current_delta + gamma; - current_delta *= &C::Scalar::DELTA; - } + let y_skip = y_powers[constraint_idx + - sets.len() + - cluster_last_constraint_idx[running_prod_cluster]]; - *value = *value * y + ((left - right) * l_active_row[idx]); - } - beta_term *= &extended_omega; + parallelize( + &mut value_part_clusters[running_prod_cluster] + [compute_part_idx_in_cluster(part_idx, running_prod_cluster)], + |values, start| { + let mut beta_term = + current_extended_omega * omega.pow_vartime([start as u64]); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let r_next = get_rotation_idx(idx, 1, rot_scale, isize); + + *value *= y_skip; + + // And for all the sets we enforce: + // (1 - (l_last(X) + l_blind(X))) * ( + // z_i(\omega X) \prod_j (p(X) + \beta s_j(X) + \gamma) + // - z_i(X) \prod_j (p(X) + \delta^j \beta X + \gamma) + // ), degree = 2 + chunk_len + let mut current_delta = delta_start * beta_term; + 
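+                            // `beta_term` tracks the coset point for this row
+                            // (current_extended_omega * omega^idx), `delta_start`
+                            // folds in beta * ZETA, and `current_delta` is advanced
+                            // by DELTA once per column below, so overall column j
+                            // sees the p(X) + delta^j * beta * X + gamma factor.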
for ( + (columns, permutation_product_coset), + permutation_coset_chunk, + ) in p + .columns + .chunks(chunk_len) + .zip(permutation_product_cosets.iter()) + .zip(permutation_cosets.chunks(chunk_len)) + { + let mut left = permutation_product_coset[r_next]; + for (values, permutation) in columns + .iter() + .map(|&column| match column.column_type() { + Any::Advice(_) => { + advice[column.index()].as_ref().unwrap() + } + Any::Fixed => { + fixed[column.index()].as_ref().unwrap() + } + Any::Instance => { + instance[column.index()].as_ref().unwrap() + } + }) + .zip(permutation_coset_chunk.iter()) + { + left *= values[idx] + beta * permutation[idx] + gamma; + } + + let mut right = permutation_product_coset[idx]; + for values in columns.iter().map(|&column| { + match column.column_type() { + Any::Advice(_) => { + advice[column.index()].as_ref().unwrap() + } + Any::Fixed => { + fixed[column.index()].as_ref().unwrap() + } + Any::Instance => { + instance[column.index()].as_ref().unwrap() + } + } + }) { + right *= values[idx] + current_delta + gamma; + current_delta *= &C::Scalar::DELTA; + } + + *value = *value * y + ((left - right) * l_active_row[idx]); + } + beta_term *= ω + } + }, + ); + cluster_last_constraint_idx[running_prod_cluster] = constraint_idx; } - }); - } + } + + // Lookups + for (n, lookup) in lookups.iter().enumerate() { + let (lookup_evaluator, max_degree, used_columns) = &self.lookups[n]; + let running_prod_cluster = + Self::compute_cluster_idx(max_degree + 2, num_clusters - 1); + if !need_to_compute(part_idx, 1) + && !need_to_compute(part_idx, 2) + && !need_to_compute(part_idx, running_prod_cluster) + { + constraint_idx += 5; + continue; + } + + // Polynomials required for this lookup. + // Calculated here so these only have to be kept in memory for the short time + // they are actually needed. + let product_coset = pk.vk.domain.coeff_to_extended_part( + lookup.product_poly.clone(), + current_extended_omega, + ); + let permuted_input_coset = pk.vk.domain.coeff_to_extended_part( + lookup.permuted_input_poly.clone(), + current_extended_omega, + ); + let permuted_table_coset = pk.vk.domain.coeff_to_extended_part( + lookup.permuted_table_poly.clone(), + current_extended_omega, + ); - // Lookups - for (n, lookup) in lookups.iter().enumerate() { - // Polynomials required for this lookup. - // Calculated here so these only have to be kept in memory for the short time - // they are actually needed. 
- let product_coset = pk.vk.domain.coeff_to_extended(&lookup.product_poly); - let permuted_input_coset = - pk.vk.domain.coeff_to_extended(&lookup.permuted_input_poly); - let permuted_table_coset = - pk.vk.domain.coeff_to_extended(&lookup.permuted_table_poly); - - // Lookup constraints - parallelize(&mut values, |values, start| { - let lookup_evaluator = &self.lookups[n]; - let mut eval_data = lookup_evaluator.instance(); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - - let table_value = lookup_evaluator.evaluate( - &mut eval_data, - fixed, - advice, - instance, - challenges, - &beta, - &gamma, - &theta, - &y, - &C::ScalarExt::ZERO, - idx, - rot_scale, - isize, + // Lookup constraints + constraint_idx += 1; + if need_to_compute(part_idx, 1) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; + + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // l_0(X) * (1 - z(X)) = 0, degree = 2 + *value = + *value * y_power + ((one - product_coset[idx]) * l0[idx]); + } + }, ); + cluster_last_constraint_idx[1] = constraint_idx; + } - let r_next = get_rotation_idx(idx, 1, rot_scale, isize); - let r_prev = get_rotation_idx(idx, -1, rot_scale, isize); - - let a_minus_s = permuted_input_coset[idx] - permuted_table_coset[idx]; - // l_0(X) * (1 - z(X)) = 0 - *value = *value * y + ((one - product_coset[idx]) * l0[idx]); - // l_last(X) * (z(X)^2 - z(X)) = 0 - *value = *value * y - + ((product_coset[idx] * product_coset[idx] - product_coset[idx]) - * l_last[idx]); - // (1 - (l_last(X) + l_blind(X))) * ( - // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) - // - z(X) (\theta^{m-1} a_0(X) + ... + a_{m-1}(X) + \beta) - // (\theta^{m-1} s_0(X) + ... + s_{m-1}(X) + \gamma) - // ) = 0 - *value = *value * y - + ((product_coset[r_next] - * (permuted_input_coset[idx] + beta) - * (permuted_table_coset[idx] + gamma) - - product_coset[idx] * table_value) - * l_active_row[idx]); - // Check that the first values in the permuted input expression and permuted - // fixed expression are the same. - // l_0(X) * (a'(X) - s'(X)) = 0 - *value = *value * y + (a_minus_s * l0[idx]); - // Check that each value in the permuted lookup input expression is either - // equal to the value above it, or the value at the same index in the - // permuted table expression. 
- // (1 - (l_last + l_blind)) * (a′(X) − s′(X))⋅(a′(X) − a′(\omega^{-1} X)) = 0 - *value = *value * y - + (a_minus_s - * (permuted_input_coset[idx] - permuted_input_coset[r_prev]) - * l_active_row[idx]); + constraint_idx += 1; + if need_to_compute(part_idx, 2) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; + parallelize( + &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // l_last(X) * (z(X)^2 - z(X)) = 0, degree = 3 + *value = *value * y_power + + ((product_coset[idx] * product_coset[idx] + - product_coset[idx]) + * l_last[idx]); + } + }, + ); + cluster_last_constraint_idx[2] = constraint_idx; } - }); - } + constraint_idx += 1; + if need_to_compute(part_idx, running_prod_cluster) { + for fixed_column in used_columns.0.iter() { + let fixed = &mut fixed[*fixed_column]; + if fixed.is_none() { + *fixed = Some(domain.coeff_to_extended_part( + pk.fixed_polys[*fixed_column].clone(), + current_extended_omega, + )); + } + } + for instance_column in used_columns.1.iter() { + let instance = &mut instance[*instance_column]; + if instance.is_none() { + *instance = Some(domain.coeff_to_extended_part( + instance_polys[*instance_column].clone(), + current_extended_omega, + )); + } + } + + for advice_column in used_columns.2.iter() { + let advice = &mut advice[*advice_column]; + if (*advice).is_none() { + *advice = Some(domain.coeff_to_extended_part( + advice_polys[*advice_column].clone(), + current_extended_omega, + )); + } + } + + let y_power = y_powers + [constraint_idx - cluster_last_constraint_idx[running_prod_cluster]]; + let fixed_slice = &fixed[..]; + let advice_slice = &advice[..]; + let instance_slice = &instance[..]; + let y_power_slice = &y_powers[..]; + parallelize( + &mut value_part_clusters[running_prod_cluster] + [compute_part_idx_in_cluster(part_idx, running_prod_cluster)], + |values, start| { + let mut eval_data = lookup_evaluator.instance(); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let table_value = lookup_evaluator.evaluate( + &mut eval_data, + fixed_slice, + advice_slice, + instance_slice, + challenges, + y_power_slice, + &beta, + &gamma, + &theta, + idx, + rot_scale, + isize, + ); - // Shuffle constraints - for (n, shuffle) in shuffles.iter().enumerate() { - let product_coset = pk.vk.domain.coeff_to_extended(&shuffle.product_poly); - - // Shuffle constraints - parallelize(&mut values, |values, start| { - let input_evaluator = &self.shuffles[2 * n]; - let shuffle_evaluator = &self.shuffles[2 * n + 1]; - let mut eval_data_input = shuffle_evaluator.instance(); - let mut eval_data_shuffle = shuffle_evaluator.instance(); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - - let input_value = input_evaluator.evaluate( - &mut eval_data_input, - fixed, - advice, - instance, - challenges, - &beta, - &gamma, - &theta, - &y, - &C::ScalarExt::ZERO, - idx, - rot_scale, - isize, + let r_next = get_rotation_idx(idx, 1, rot_scale, isize); + + // (1 - (l_last(X) + l_blind(X))) * ( + // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) + // - z(X) (\theta^{m-1} a_0(X) + ... + a_{m-1}(X) + \beta) + // (\theta^{m-1} s_0(X) + ... 
+ s_{m-1}(X) + \gamma) + // ) = 0, degree = 2 + max(deg(a)) + max(deg(s)) + *value = *value * y_power + + ((product_coset[r_next] + * (permuted_input_coset[idx] + beta) + * (permuted_table_coset[idx] + gamma) + - product_coset[idx] * table_value) + * l_active_row[idx]); + } + }, ); + cluster_last_constraint_idx[running_prod_cluster] = constraint_idx; + } - let shuffle_value = shuffle_evaluator.evaluate( - &mut eval_data_shuffle, - fixed, - advice, - instance, - challenges, - &beta, - &gamma, - &theta, - &y, - &C::ScalarExt::ZERO, - idx, - rot_scale, - isize, + constraint_idx += 1; + if need_to_compute(part_idx, 1) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let a_minus_s = + permuted_input_coset[idx] - permuted_table_coset[idx]; + // Check that the first values in the permuted input expression and permuted + // fixed expression are the same. + // l_0(X) * (a'(X) - s'(X)) = 0, degree = 2 + *value = *value * y_power + (a_minus_s * l0[idx]); + } + }, ); + cluster_last_constraint_idx[1] = constraint_idx; + } + + constraint_idx += 1; + if need_to_compute(part_idx, 2) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; + parallelize( + &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let r_prev = get_rotation_idx(idx, -1, rot_scale, isize); - let r_next = get_rotation_idx(idx, 1, rot_scale, isize); - - // l_0(X) * (1 - z(X)) = 0 - *value = *value * y + ((one - product_coset[idx]) * l0[idx]); - // l_last(X) * (z(X)^2 - z(X)) = 0 - *value = *value * y - + ((product_coset[idx] * product_coset[idx] - product_coset[idx]) - * l_last[idx]); - // (1 - (l_last(X) + l_blind(X))) * (z(\omega X) (s(X) + \gamma) - z(X) (a(X) + \gamma)) = 0 - *value = *value * y - + l_active_row[idx] - * (product_coset[r_next] * shuffle_value - - product_coset[idx] * input_value) + // Check that each value in the permuted lookup input expression is either + // equal to the value above it, or the value at the same index in the + // permuted table expression. + // (1 - (l_last + l_blind)) * (a′(X) − s′(X))⋅(a′(X) − a′(\omega^{-1} X)) = 0, degree = 3 + let a_minus_s = + permuted_input_coset[idx] - permuted_table_coset[idx]; + *value = *value * y_power + + (a_minus_s + * (permuted_input_coset[idx] + - permuted_input_coset[r_prev]) + * l_active_row[idx]); + } + }, + ); + cluster_last_constraint_idx[2] = constraint_idx; } - }); + } } + // Align the constraints by different powers of y. 
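+            // Each cluster stopped folding at its own last constraint index, so
+            // scale every active cluster up to the common power y^constraint_idx
+            // before the parts are recombined in `lagrange_vecs_to_extended`.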
+ for (i, cluster) in value_part_clusters.iter_mut().enumerate() { + if need_to_compute(part_idx, i) && cluster_last_constraint_idx[i] > 0 { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[i]]; + parallelize( + &mut cluster[compute_part_idx_in_cluster(part_idx, i)], + |values, _| { + for value in values.iter_mut() { + *value *= y_power; + } + }, + ); + } + } + current_extended_omega *= extended_omega; + } + domain.lagrange_vecs_to_extended(value_part_clusters) + } + + fn compute_cluster_idx(degree: usize, max_cluster_idx: usize) -> usize { + let mut idx = (31 - (degree as u32).leading_zeros()) as usize; + if 1 << idx < degree { + idx += 1; } - values + std::cmp::min(max_cluster_idx, idx) } } @@ -780,15 +1140,14 @@ impl GraphEvaluator { pub fn evaluate( &self, data: &mut EvaluationData, - fixed: &[Polynomial], - advice: &[Polynomial], - instance: &[Polynomial], + fixed: &[Option>], + advice: &[Option>], + instance: &[Option>], challenges: &[C::ScalarExt], + y_powers: &[C::ScalarExt], beta: &C::ScalarExt, gamma: &C::ScalarExt, theta: &C::ScalarExt, - y: &C::ScalarExt, - previous_value: &C::ScalarExt, idx: usize, rot_scale: i32, isize: i32, @@ -808,11 +1167,10 @@ impl GraphEvaluator { advice, instance, challenges, + y_powers, beta, gamma, theta, - y, - previous_value, ); } diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index 81ca52cf05..a5b7e87e92 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -333,11 +333,6 @@ where .map(|poly| vk.domain.lagrange_to_coeff(poly.clone())) .collect(); - let fixed_cosets = fixed_polys - .iter() - .map(|poly| vk.domain.coeff_to_extended(poly)) - .collect(); - let permutation_pk = assembly .permutation .build_pk(params, &vk.domain, &cs.permutation); @@ -347,7 +342,6 @@ where let mut l0 = vk.domain.empty_lagrange(); l0[0] = C::Scalar::ONE; let l0 = vk.domain.lagrange_to_coeff(l0); - let l0 = vk.domain.coeff_to_extended(&l0); // Compute l_blind(X) which evaluates to 1 for each blinding factor row // and 0 otherwise over the domain. 
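To see the clustering arithmetic above with concrete numbers, here is a self-contained sketch: plain `u64` values stand in for field elements, and everything besides the copied body of `compute_cluster_idx` is illustrative. Constraints are bucketed by `ceil(log2(degree))`, each cluster folds its own constraints with `y^gap` steps, and aligning every cluster to `y^(total constraints)` recovers exactly the single Horner pass the old code performed:

```rust
// Copied shape of Evaluator::compute_cluster_idx; the rest is a toy model.
fn compute_cluster_idx(degree: usize, max_cluster_idx: usize) -> usize {
    // ceil(log2(degree)), clamped to the largest available cluster.
    let mut idx = (31 - (degree as u32).leading_zeros()) as usize;
    if 1 << idx < degree {
        idx += 1;
    }
    std::cmp::min(max_cluster_idx, idx)
}

fn main() {
    // With cs.degree() = 5, the quotient polynomial degree is 4, giving
    // clusters 0..=2 for constraint degrees 1, 2 and 3..=4.
    let quotient_poly_degree = 4u64;
    let mut max_cluster_idx = 0;
    while (1 << max_cluster_idx) < quotient_poly_degree {
        max_cluster_idx += 1;
    }
    assert_eq!(max_cluster_idx, 2);
    assert_eq!(compute_cluster_idx(2, max_cluster_idx), 1);
    assert_eq!(compute_cluster_idx(3, max_cluster_idx), 2);
    assert_eq!(compute_cluster_idx(9, max_cluster_idx), 2); // clamped

    // Four constraints of degrees 1, 2, 3, 4 with values 7, 3, 5, 2 and a toy
    // y = 10: the reference is one Horner pass over all of them.
    let y: u64 = 10;
    let (v, degrees) = ([7u64, 3, 5, 2], [1usize, 2, 3, 4]);
    let reference = v.iter().fold(0u64, |acc, &x| acc * y + x);

    // Clustered evaluation folds each cluster separately, stepping by
    // y^(gap to the previous constraint in the same cluster)...
    let mut acc = vec![0u64; max_cluster_idx + 1];
    let mut last = vec![0usize; max_cluster_idx + 1];
    for (i, (&val, &deg)) in v.iter().zip(degrees.iter()).enumerate() {
        let c = compute_cluster_idx(deg, max_cluster_idx);
        acc[c] = acc[c] * y.pow((i + 1 - last[c]) as u32) + val;
        last[c] = i + 1;
    }
    // ...and then aligns every cluster to y^(total constraints), exactly the
    // final "align the constraints by different powers of y" pass.
    let total = v.len();
    let clustered: u64 = acc
        .iter()
        .zip(last.iter())
        .map(|(&a, &l)| a * y.pow((total - l) as u32))
        .sum();
    assert_eq!(clustered, reference);
}
```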
@@ -355,19 +349,15 @@ where for evaluation in l_blind[..].iter_mut().rev().take(cs.blinding_factors()) { *evaluation = C::Scalar::ONE; } - let l_blind = vk.domain.lagrange_to_coeff(l_blind); - let l_blind = vk.domain.coeff_to_extended(&l_blind); // Compute l_last(X) which evaluates to 1 on the first inactive row (just // before the blinding factors) and 0 otherwise over the domain let mut l_last = vk.domain.empty_lagrange(); l_last[params.n() as usize - cs.blinding_factors() - 1] = C::Scalar::ONE; - let l_last = vk.domain.lagrange_to_coeff(l_last); - let l_last = vk.domain.coeff_to_extended(&l_last); // Compute l_active_row(X) let one = C::Scalar::ONE; - let mut l_active_row = vk.domain.empty_extended(); + let mut l_active_row = vk.domain.empty_lagrange(); parallelize(&mut l_active_row, |values, start| { for (i, value) in values.iter_mut().enumerate() { let idx = i + start; @@ -375,6 +365,9 @@ where } }); + let l_last = vk.domain.lagrange_to_coeff(l_last); + let l_active_row = vk.domain.lagrange_to_coeff(l_active_row); + // Compute the optimized evaluation data structure let ev = Evaluator::new(&vk.cs); @@ -385,7 +378,6 @@ where l_active_row, fixed_values: fixed, fixed_polys, - fixed_cosets, permutation: permutation_pk, ev, }) diff --git a/halo2_proofs/src/plonk/permutation.rs b/halo2_proofs/src/plonk/permutation.rs index 6c5896b978..479f72f4ab 100644 --- a/halo2_proofs/src/plonk/permutation.rs +++ b/halo2_proofs/src/plonk/permutation.rs @@ -7,7 +7,7 @@ use crate::{ polynomial_slice_byte_length, read_polynomial_vec, write_polynomial_slice, SerdeCurveAffine, SerdePrimeField, }, - poly::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial}, + poly::{Coeff, LagrangeCoeff, Polynomial}, SerdeFormat, }; @@ -126,39 +126,33 @@ impl VerifyingKey { #[derive(Clone, Debug)] pub(crate) struct ProvingKey { permutations: Vec>, - polys: Vec>, - pub(super) cosets: Vec>, + pub(super) polys: Vec>, } impl ProvingKey where C::Scalar: SerdePrimeField, { - /// Reads proving key for a single permutation argument from buffer using `Polynomial::read`. + /// Reads proving key for a single permutation argument from buffer using `Polynomial::read`. pub(super) fn read(reader: &mut R, format: SerdeFormat) -> Self { let permutations = read_polynomial_vec(reader, format); let polys = read_polynomial_vec(reader, format); - let cosets = read_polynomial_vec(reader, format); ProvingKey { permutations, polys, - cosets, } } - /// Writes proving key for a single permutation argument to buffer using `Polynomial::write`. + /// Writes proving key for a single permutation argument to buffer using `Polynomial::write`. 
pub(super) fn write(&self, writer: &mut W, format: SerdeFormat) { write_polynomial_slice(&self.permutations, writer, format); write_polynomial_slice(&self.polys, writer, format); - write_polynomial_slice(&self.cosets, writer, format); } } impl ProvingKey { /// Gets the total number of bytes in the serialization of `self` pub(super) fn bytes_length(&self) -> usize { - polynomial_slice_byte_length(&self.permutations) - + polynomial_slice_byte_length(&self.polys) - + polynomial_slice_byte_length(&self.cosets) + polynomial_slice_byte_length(&self.permutations) + polynomial_slice_byte_length(&self.polys) } } diff --git a/halo2_proofs/src/plonk/permutation/keygen.rs b/halo2_proofs/src/plonk/permutation/keygen.rs index 94fa4518db..1bad10f3b8 100644 --- a/halo2_proofs/src/plonk/permutation/keygen.rs +++ b/halo2_proofs/src/plonk/permutation/keygen.rs @@ -394,21 +394,9 @@ pub(crate) fn build_pk<'params, C: CurveAffine, P: Params<'params, C>>( }); } - let mut cosets = vec![domain.empty_extended(); p.columns.len()]; - { - parallelize(&mut cosets, |o, start| { - for (x, coset) in o.iter_mut().enumerate() { - let i = start + x; - let poly = &polys[i]; - *coset = domain.coeff_to_extended(poly); - } - }); - } - ProvingKey { permutations, polys, - cosets, } } diff --git a/halo2_proofs/src/plonk/permutation/prover.rs b/halo2_proofs/src/plonk/permutation/prover.rs index 3565723399..28dfa7f4eb 100644 --- a/halo2_proofs/src/plonk/permutation/prover.rs +++ b/halo2_proofs/src/plonk/permutation/prover.rs @@ -13,14 +13,13 @@ use crate::{ plonk::{self, Error}, poly::{ commitment::{Blind, Params}, - Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, ProverQuery, Rotation, + Coeff, LagrangeCoeff, Polynomial, ProverQuery, Rotation, }, transcript::{EncodedChallenge, TranscriptWrite}, }; pub(crate) struct CommittedSet { pub(crate) permutation_product_poly: Polynomial, - pub(crate) permutation_product_coset: Polynomial, permutation_product_blind: Blind, } @@ -173,8 +172,6 @@ impl Argument { let z = domain.lagrange_to_coeff(z); let permutation_product_poly = z.clone(); - let permutation_product_coset = domain.coeff_to_extended(&z); - let permutation_product_commitment = permutation_product_commitment_projective.to_affine(); @@ -183,7 +180,6 @@ impl Argument { sets.push(CommittedSet { permutation_product_poly, - permutation_product_coset, permutation_product_blind, }); } diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index ef48e77e08..95bfa5ef07 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -17,8 +17,8 @@ use super::{ Advice, Any, Assignment, Challenge, Circuit, Column, ConstraintSystem, Fixed, FloorPlanner, Instance, Selector, }, - lookup, permutation, shuffle, vanishing, ChallengeBeta, ChallengeGamma, ChallengeTheta, - ChallengeX, ChallengeY, Error, ProvingKey, + lookup, permutation, vanishing, ChallengeBeta, ChallengeGamma, ChallengeTheta, ChallengeX, + ChallengeY, Error, ProvingKey, }; use crate::{ @@ -545,40 +545,6 @@ where #[cfg(feature = "profile")] end_timer!(phase3b_time); - #[cfg(feature = "profile")] - let shuffle_time = start_timer!(|| "Shuffles"); - let shuffles: Vec>> = instance - .iter() - .zip(advice.iter()) - .map(|(instance, advice)| -> Vec<_> { - // Compress expressions for each shuffle - pk.vk - .cs - .shuffles - .iter() - .map(|shuffle| { - shuffle - .commit_product( - pk, - params, - domain, - theta, - gamma, - &advice.advice_polys, - &pk.fixed_values, - &instance.instance_values, - &challenges, - &mut rng, - transcript, 
- ) - .unwrap() - }) - .collect() - }) - .collect(); - #[cfg(feature = "profile")] - end_timer!(shuffle_time); - #[cfg(feature = "profile")] let vanishing_time = start_timer!(|| "Commit to vanishing argument's random poly"); // Commit to the vanishing argument's random polynomial for blinding h(x_3) @@ -632,7 +598,6 @@ where *gamma, *theta, &lookups, - &shuffles, &permutations, ); #[cfg(feature = "profile")] @@ -730,24 +695,12 @@ where #[cfg(feature = "profile")] end_timer!(eval_time); - // Evaluate the shuffles, if any, at omega^i x. - let shuffles: Vec>> = shuffles - .into_iter() - .map(|shuffles| -> Result, _> { - shuffles - .into_iter() - .map(|p| p.evaluate(pk, x, transcript)) - .collect::, _>>() - }) - .collect::, _>>()?; - let instances = instance .iter() .zip(advice.iter()) .zip(permutations.iter()) .zip(lookups.iter()) - .zip(shuffles.iter()) - .flat_map(|((((instance, advice), permutation), lookups), shuffles)| { + .flat_map(|(((instance, advice), permutation), lookups)| { iter::empty() .chain( P::QUERY_INSTANCE @@ -774,7 +727,6 @@ where ) .chain(permutation.open(pk, x)) .chain(lookups.iter().flat_map(move |p| p.open(pk, x))) - .chain(shuffles.iter().flat_map(move |p| p.open(pk, x))) }) .chain( pk.vk diff --git a/halo2_proofs/src/plonk/verifier.rs b/halo2_proofs/src/plonk/verifier.rs index 76675bcdfa..068c6a93e9 100644 --- a/halo2_proofs/src/plonk/verifier.rs +++ b/halo2_proofs/src/plonk/verifier.rs @@ -157,17 +157,6 @@ where }) .collect::, _>>()?; - let shuffles_committed = (0..num_proofs) - .map(|_| -> Result, _> { - // Hash each shuffle product commitment - vk.cs - .shuffles - .iter() - .map(|argument| argument.read_product_commitment(transcript)) - .collect::, _>>() - }) - .collect::, _>>()?; - let vanishing = vanishing::Argument::read_commitments_before_y(transcript)?; // Sample y challenge, which keeps the gates linearly independent. @@ -250,16 +239,6 @@ where }) .collect::, _>>()?; - let shuffles_evaluated = shuffles_committed - .into_iter() - .map(|shuffles| -> Result, _> { - shuffles - .into_iter() - .map(|shuffle| shuffle.evaluate(transcript)) - .collect::, _>>() - }) - .collect::, _>>()?; - // This check ensures the circuit is satisfied so long as the polynomial // commitments open to the correct values. 
let vanishing = { @@ -283,80 +262,59 @@ where .zip(instance_evals.iter()) .zip(permutations_evaluated.iter()) .zip(lookups_evaluated.iter()) - .zip(shuffles_evaluated.iter()) - .flat_map( - |((((advice_evals, instance_evals), permutation), lookups), shuffles)| { - let challenges = &challenges; - let fixed_evals = &fixed_evals; - std::iter::empty() - // Evaluate the circuit using the custom gates provided - .chain(vk.cs.gates.iter().flat_map(move |gate| { - gate.polynomials().iter().map(move |poly| { - poly.evaluate( - &|scalar| scalar, - &|_| { - panic!("virtual selectors are removed during optimization") - }, - &|query| fixed_evals[query.index.unwrap()], - &|query| advice_evals[query.index.unwrap()], - &|query| instance_evals[query.index.unwrap()], - &|challenge| challenges[challenge.index()], - &|a| -a, - &|a, b| a + &b, - &|a, b| a * &b, - &|a, scalar| a * &scalar, - ) - }) - })) - .chain(permutation.expressions( - vk, - &vk.cs.permutation, - &permutations_common, - advice_evals, - fixed_evals, - instance_evals, - l_0, - l_last, - l_blind, - beta, - gamma, - x, - )) - .chain(lookups.iter().zip(vk.cs.lookups.iter()).flat_map( - move |(p, argument)| { - p.expressions( - l_0, - l_last, - l_blind, - argument, - theta, - beta, - gamma, - advice_evals, - fixed_evals, - instance_evals, - challenges, - ) - }, - )) - .chain(shuffles.iter().zip(vk.cs.shuffles.iter()).flat_map( - move |(p, argument)| { - p.expressions( - l_0, - l_last, - l_blind, - argument, - theta, - gamma, - advice_evals, - fixed_evals, - instance_evals, - challenges, - ) - }, - )) - }, - ); + .flat_map(|(((advice_evals, instance_evals), permutation), lookups)| { + let challenges = &challenges; + let fixed_evals = &fixed_evals; + std::iter::empty() + // Evaluate the circuit using the custom gates provided + .chain(vk.cs.gates.iter().flat_map(move |gate| { + gate.polynomials().iter().map(move |poly| { + poly.evaluate( + &|scalar| scalar, + &|_| panic!("virtual selectors are removed during optimization"), + &|query| fixed_evals[query.index.unwrap()], + &|query| advice_evals[query.index.unwrap()], + &|query| instance_evals[query.index.unwrap()], + &|challenge| challenges[challenge.index()], + &|a| -a, + &|a, b| a + &b, + &|a, b| a * &b, + &|a, scalar| a * &scalar, + ) + }) + })) + .chain(permutation.expressions( + vk, + &vk.cs.permutation, + &permutations_common, + advice_evals, + fixed_evals, + instance_evals, + l_0, + l_last, + l_blind, + beta, + gamma, + x, + )) + .chain(lookups.iter().zip(vk.cs.lookups.iter()).flat_map( + move |(p, argument)| { + p.expressions( + l_0, + l_last, + l_blind, + argument, + theta, + beta, + gamma, + advice_evals, + fixed_evals, + instance_evals, + challenges, + ) + }, + )) + }); vanishing.verify(params, expressions, y, xn) }; @@ -368,20 +326,13 @@ where .zip(advice_evals.iter()) .zip(permutations_evaluated.iter()) .zip(lookups_evaluated.iter()) - .zip(shuffles_evaluated.iter()) .flat_map( |( ( - ( - ( - ((instance_commitments, instance_evals), advice_commitments), - advice_evals, - ), - permutation, - ), - lookups, + (((instance_commitments, instance_evals), advice_commitments), advice_evals), + permutation, ), - shuffles, + lookups, )| { iter::empty() .chain( @@ -409,7 +360,6 @@ where )) .chain(permutation.queries(vk, x)) .chain(lookups.iter().flat_map(move |p| p.queries(vk, x))) - .chain(shuffles.iter().flat_map(move |p| p.queries(vk, x))) }, ) .chain( diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index 7e285c6d23..3e88e06c0a 100644 --- 
a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -2,8 +2,8 @@ //! domain that is of a suitable size for the application. use crate::{ - arithmetic::{best_fft, parallelize, parallelize_count}, - multicore, + arithmetic::{best_fft, parallelize}, + fft::recursive::FFTData, plonk::Assigned, }; @@ -11,395 +11,7 @@ use super::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation}; use group::ff::{BatchInvert, Field, WithSmallOrderMulGroup}; -use std::{env::var, marker::PhantomData}; - -fn get_fft_mode() -> usize { - var("FFT_MODE") - .unwrap_or_else(|_| "1".to_string()) - .parse() - .expect("Cannot parse FFT_MODE env var as usize") -} - -/// FFTStage -#[derive(Clone, Debug)] -pub struct FFTStage { - radix: usize, - length: usize, -} - -/// FFT stages -pub fn get_stages(size: usize, radixes: Vec) -> Vec { - let mut stages: Vec = vec![]; - - let mut n = size; - - // Use the specified radices - for &radix in &radixes { - n /= radix; - stages.push(FFTStage { radix, length: n }); - } - - // Fill in the rest of the tree if needed - let mut p = 2; - while n > 1 { - while n % p != 0 { - if p == 4 { - p = 2; - } - } - n /= p; - stages.push(FFTStage { - radix: p, - length: n, - }); - } - - /*for i in 0..stages.len() { - println!("Stage {}: {}, {}", i, stages[i].radix, stages[i].length); - }*/ - - stages -} - -/// FFTData -#[derive(Clone, Debug)] -struct FFTData { - n: usize, - - stages: Vec, - - f_twiddles: Vec>, - inv_twiddles: Vec>, - //scratch: Vec, -} - -impl FFTData { - /// Create FFT data - pub fn new(n: usize, omega: F, omega_inv: F) -> Self { - let stages = get_stages(n, vec![]); - let mut f_twiddles = vec![]; - let mut inv_twiddles = vec![]; - let mut scratch = vec![F::ONE; n]; - - // Generate stage twiddles - for inv in 0..2 { - let inverse = inv == 0; - let o = if inverse { omega_inv } else { omega }; - let stage_twiddles = if inverse { - &mut inv_twiddles - } else { - &mut f_twiddles - }; - - let twiddles = &mut scratch; - - // Twiddles - parallelize(twiddles, |twiddles, start| { - let w_m = o; - let mut w = o.pow_vartime([start as u64]); - for value in twiddles.iter_mut() { - *value = w; - w *= w_m; - } - }); - - // Re-order twiddles for cache friendliness - let num_stages = stages.len(); - stage_twiddles.resize(num_stages, vec![]); - for l in 0..num_stages { - let radix = stages[l].radix; - let stage_length = stages[l].length; - - let num_twiddles = stage_length * (radix - 1); - stage_twiddles[l].resize(num_twiddles + 1, F::ZERO); - - // Set j - stage_twiddles[l][num_twiddles] = twiddles[(twiddles.len() * 3) / 4]; - - let stride = n / (stage_length * radix); - let mut tws = vec![0usize; radix - 1]; - for i in 0..stage_length { - for j in 0..radix - 1 { - stage_twiddles[l][i * (radix - 1) + j] = twiddles[tws[j]]; - tws[j] += (j + 1) * stride; - } - } - } - } - - Self { - n, - stages, - f_twiddles, - inv_twiddles, - //scratch, - } - } -} - -/// Radix 2 butterfly -pub fn butterfly_2(out: &mut [F], twiddles: &[F], stage_length: usize) { - let mut out_offset = 0; - let mut out_offset2 = stage_length; - - let t = out[out_offset2]; - out[out_offset2] = out[out_offset] - t; - out[out_offset] += t; - out_offset2 += 1; - out_offset += 1; - - for twiddle in twiddles[1..stage_length].iter() { - let t = *twiddle * out[out_offset2]; - out[out_offset2] = out[out_offset] - t; - out[out_offset] += t; - out_offset2 += 1; - out_offset += 1; - } -} - -/// Radix 2 butterfly -fn butterfly_2_parallel( - out: &mut [F], - twiddles: &[F], - _stage_length: usize, - num_threads: 
usize, -) { - let n = out.len(); - let mut chunk = n / num_threads; - if chunk < num_threads { - chunk = n; - } - - multicore::scope(|scope| { - let (part_a, part_b) = out.split_at_mut(n / 2); - for (i, (part0, part1)) in part_a - .chunks_mut(chunk) - .zip(part_b.chunks_mut(chunk)) - .enumerate() - { - scope.spawn(move |_| { - let offset = i * chunk; - for k in 0..part0.len() { - let t = twiddles[offset + k] * part1[k]; - part1[k] = part0[k] - t; - part0[k] += t; - } - }); - } - }); -} - -/// Radix 4 butterfly -pub fn butterfly_4(out: &mut [F], twiddles: &[F], stage_length: usize) { - let j = twiddles[twiddles.len() - 1]; - let mut tw = 0; - - /* Case twiddle == one */ - { - let i0 = 0; - let i1 = stage_length; - let i2 = stage_length * 2; - let i3 = stage_length * 3; - - let z0 = out[i0]; - let z1 = out[i1]; - let z2 = out[i2]; - let z3 = out[i3]; - - let t1 = z0 + z2; - let t2 = z1 + z3; - let t3 = z0 - z2; - let t4j = j * (z1 - z3); - - out[i0] = t1 + t2; - out[i1] = t3 - t4j; - out[i2] = t1 - t2; - out[i3] = t3 + t4j; - - tw += 3; - } - - for k in 1..stage_length { - let i0 = k; - let i1 = k + stage_length; - let i2 = k + stage_length * 2; - let i3 = k + stage_length * 3; - - let z0 = out[i0]; - let z1 = out[i1] * twiddles[tw]; - let z2 = out[i2] * twiddles[tw + 1]; - let z3 = out[i3] * twiddles[tw + 2]; - - let t1 = z0 + z2; - let t2 = z1 + z3; - let t3 = z0 - z2; - let t4j = j * (z1 - z3); - - out[i0] = t1 + t2; - out[i1] = t3 - t4j; - out[i2] = t1 - t2; - out[i3] = t3 + t4j; - - tw += 3; - } -} - -/// Radix 4 butterfly -pub fn butterfly_4_parallel( - out: &mut [F], - twiddles: &[F], - _stage_length: usize, - num_threads: usize, -) { - let j = twiddles[twiddles.len() - 1]; - - let n = out.len(); - let mut chunk = n / num_threads; - if chunk < num_threads { - chunk = n; - } - multicore::scope(|scope| { - //let mut parts: Vec<&mut [F]> = out.chunks_mut(4).collect(); - //out.chunks_mut(4).map(|c| c.chunks_mut(chunk)).fold(predicate) - let (part_a, part_b) = out.split_at_mut(n / 2); - let (part_aa, part_ab) = part_a.split_at_mut(n / 4); - let (part_ba, part_bb) = part_b.split_at_mut(n / 4); - for (i, (((part0, part1), part2), part3)) in part_aa - .chunks_mut(chunk) - .zip(part_ab.chunks_mut(chunk)) - .zip(part_ba.chunks_mut(chunk)) - .zip(part_bb.chunks_mut(chunk)) - .enumerate() - { - scope.spawn(move |_| { - let offset = i * chunk; - let mut tw = offset * 3; - for k in 0..part1.len() { - let z0 = part0[k]; - let z1 = part1[k] * twiddles[tw]; - let z2 = part2[k] * twiddles[tw + 1]; - let z3 = part3[k] * twiddles[tw + 2]; - - let t1 = z0 + z2; - let t2 = z1 + z3; - let t3 = z0 - z2; - let t4j = j * (z1 - z3); - - part0[k] = t1 + t2; - part1[k] = t3 - t4j; - part2[k] = t1 - t2; - part3[k] = t3 + t4j; - - tw += 3; - } - }); - } - }); -} - -/// Inner recursion -#[allow(clippy::too_many_arguments)] -fn recursive_fft_inner( - data_in: &[F], - data_out: &mut [F], - twiddles: &Vec>, - stages: &Vec, - in_offset: usize, - stride: usize, - level: usize, - num_threads: usize, -) { - let radix = stages[level].radix; - let stage_length = stages[level].length; - - if num_threads > 1 { - if stage_length == 1 { - for i in 0..radix { - data_out[i] = data_in[in_offset + i * stride]; - } - } else { - let num_threads_recursive = if num_threads >= radix { - radix - } else { - num_threads - }; - parallelize_count(data_out, num_threads_recursive, |data_out, i| { - let num_threads_in_recursion = if num_threads < radix { - 1 - } else { - (num_threads + i) / radix - }; - recursive_fft_inner( - data_in, - 
data_out, - twiddles, - stages, - in_offset + i * stride, - stride * radix, - level + 1, - num_threads_in_recursion, - ) - }); - } - match radix { - 2 => butterfly_2_parallel(data_out, &twiddles[level], stage_length, num_threads), - 4 => butterfly_4_parallel(data_out, &twiddles[level], stage_length, num_threads), - _ => unimplemented!("radix unsupported"), - } - } else { - if stage_length == 1 { - for i in 0..radix { - data_out[i] = data_in[in_offset + i * stride]; - } - } else { - for i in 0..radix { - recursive_fft_inner( - data_in, - &mut data_out[i * stage_length..(i + 1) * stage_length], - twiddles, - stages, - in_offset + i * stride, - stride * radix, - level + 1, - num_threads, - ); - } - } - match radix { - 2 => butterfly_2(data_out, &twiddles[level], stage_length), - 4 => butterfly_4(data_out, &twiddles[level], stage_length), - _ => unimplemented!("radix unsupported"), - } - } -} - -fn recursive_fft(data: &FFTData, data_in: &mut Vec, inverse: bool) { - let num_threads = multicore::current_num_threads(); - //let start = start_measure(format!("recursive fft {} ({})", data_in.len(), num_threads), false); - - // TODO: reuse scratch buffer between FFTs - //let start_mem = start_measure(format!("alloc"), false); - let mut scratch = vec![F::ZERO; data_in.len()]; - //stop_measure(start_mem); - - recursive_fft_inner( - data_in, - &mut /*data.*/scratch, - if inverse { - &data.inv_twiddles - } else { - &data.f_twiddles - }, - &data.stages, - 0, - 1, - 0, - num_threads, - ); - - // Will simply swap the vector's buffer, no data is actually copied - std::mem::swap(data_in, &mut /*data.*/scratch); -} +use std::marker::PhantomData; /// This structure contains precomputed constants and other details needed for /// performing operations on an evaluation domain of size $2^k$ and an extended @@ -581,6 +193,32 @@ impl> EvaluationDomain { } } + /// Obtains a polynomial in ExtendedLagrange form when given a vector of + /// Lagrange polynomials with total size `extended_n`; panics if the + /// provided vector is the wrong length. + pub fn lagrange_vec_to_extended( + &self, + values: Vec>, + ) -> Polynomial { + assert_eq!(values.len(), self.extended_len() >> self.k); + assert_eq!(values[0].len(), self.n as usize); + + // transpose the values in parallel + let mut transposed = vec![vec![F::ZERO; values.len()]; self.n as usize]; + values.into_iter().enumerate().for_each(|(i, p)| { + parallelize(&mut transposed, |transposed, start| { + for (transposed, p) in transposed.iter_mut().zip(p.values[start..].iter()) { + transposed[i] = *p; + } + }); + }); + + Polynomial { + values: transposed.into_iter().flatten().collect(), + _marker: PhantomData, + } + } + /// Returns an empty (zero) polynomial in the coefficient basis pub fn empty_coeff(&self) -> Polynomial { Polynomial { @@ -669,6 +307,82 @@ impl> EvaluationDomain { } } + /// This takes us from an n-length coefficient vector into parts of the + /// extended evaluation domain. For example, for a polynomial with size n, + /// and an extended domain of size mn, we can compute all parts + /// independently, which are + /// `FFT(f(zeta * X), n)` + /// `FFT(f(zeta * extended_omega * X), n)` + /// ... 
+ /// `FFT(f(zeta * extended_omega^{m-1} * X), n)` + pub fn coeff_to_extended_parts( + &self, + a: &Polynomial, + ) -> Vec> { + assert_eq!(a.values.len(), 1 << self.k); + + let num_parts = self.extended_len() >> self.k; + let mut extended_omega_factor = F::ONE; + (0..num_parts) + .map(|_| { + let part = self.coeff_to_extended_part(a.clone(), extended_omega_factor); + extended_omega_factor *= self.extended_omega; + part + }) + .collect() + } + + /// This takes us from several n-length coefficient vectors each into parts + /// of the extended evaluation domain. For example, for a polynomial with + /// size n, and an extended domain of size mn, we can compute all parts + /// independently, which are + /// `FFT(f(zeta * X), n)` + /// `FFT(f(zeta * extended_omega * X), n)` + /// ... + /// `FFT(f(zeta * extended_omega^{m-1} * X), n)` + pub fn batched_coeff_to_extended_parts( + &self, + a: &[Polynomial], + ) -> Vec>> { + assert_eq!(a[0].values.len(), 1 << self.k); + + let mut extended_omega_factor = F::ONE; + let num_parts = self.extended_len() >> self.k; + (0..num_parts) + .map(|_| { + let a_lagrange = a + .iter() + .map(|poly| self.coeff_to_extended_part(poly.clone(), extended_omega_factor)) + .collect(); + extended_omega_factor *= self.extended_omega; + a_lagrange + }) + .collect() + } + + /// This takes us from an n-length coefficient vector into a part of the + /// extended evaluation domain. For example, for a polynomial with size n, + /// and an extended domain of size mn, we can compute one of the m parts + /// separately, which is + /// `FFT(f(zeta * extended_omega_factor * X), n)` + /// where `extended_omega_factor` is `extended_omega^i` with `i` in `[0, m)`. + pub fn coeff_to_extended_part( + &self, + mut a: Polynomial, + extended_omega_factor: F, + ) -> Polynomial { + assert_eq!(a.values.len(), 1 << self.k); + + self.distribute_powers(&mut a.values, self.g_coset * extended_omega_factor); + let data = self.get_fft_data(a.len()); + best_fft(&mut a.values, self.omega, self.k, data, false); + + Polynomial { + values: a.values, + _marker: PhantomData, + } + } + /// Rotate the extended domain polynomial over the original domain. pub fn rotate_extended( &self, @@ -718,6 +432,72 @@ impl> EvaluationDomain { a.values } + /// This takes us from the a list of lagrange-based polynomials with + /// different degrees and gets their extended lagrange-based summation. + pub fn lagrange_vecs_to_extended( + &self, + mut a: Vec>>, + ) -> Polynomial { + let mut result_poly = if a[a.len() - 1].len() == 1 << (self.extended_k - self.k) { + self.lagrange_vec_to_extended(a.pop().unwrap()) + } else { + self.empty_extended() + }; + + // Transform from each cluster of lagrange representations to coeff representations. 
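+        // An IFFT of size 2^(k + i) needs the root extended_omega_inv^(2^(extended_k - k - i))
+        // and the divisor 1 / 2^(k + i); starting from the full extended-domain constants,
+        // doubling `ifft_divisor` and squaring `omega_inv` once per halving of the
+        // cluster size (in the loop just below, and again per iteration further down)
+        // keeps both in step with the current cluster.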
+ let mut ifft_divisor = self.extended_ifft_divisor; + let mut omega_inv = self.extended_omega_inv; + { + let mut i = a.last().unwrap().len() << self.k; + while i < (1 << self.extended_k) { + ifft_divisor = ifft_divisor + ifft_divisor; + omega_inv = omega_inv * omega_inv; + i <<= 1; + } + } + + let mut result = vec![F::ZERO; 1 << self.extended_k as usize]; + for (i, a_parts) in a.into_iter().enumerate().rev() { + // transpose the values in parallel + assert_eq!(1 << i, a_parts.len()); + let mut a_poly: Vec = { + let mut transposed = vec![vec![F::ZERO; a_parts.len()]; self.n as usize]; + a_parts.into_iter().enumerate().for_each(|(j, p)| { + parallelize(&mut transposed, |transposed, start| { + for (transposed, p) in transposed.iter_mut().zip(p.values[start..].iter()) { + transposed[j] = *p; + } + }); + }); + transposed.into_iter().flatten().collect() + }; + + self.ifft(&mut a_poly, omega_inv, self.k + i as u32, ifft_divisor); + ifft_divisor = ifft_divisor + ifft_divisor; + omega_inv = omega_inv * omega_inv; + + parallelize(&mut result[0..(self.n << i) as usize], |result, start| { + for (other, current) in result.iter_mut().zip(a_poly[start..].iter()) { + *other += current; + } + }); + } + let data = self.get_fft_data(result.len()); + best_fft( + &mut result, + self.extended_omega, + self.extended_k, + data, + false, + ); + parallelize(&mut result_poly.values, |values, start| { + for (value, other) in values.iter_mut().zip(result[start..].iter()) { + *value += other; + } + }); + result_poly + } + /// This divides the polynomial (in the extended domain) by the vanishing /// polynomial of the $2^k$ size domain. pub fn divide_by_vanishing_poly( @@ -766,6 +546,19 @@ impl> EvaluationDomain { }); } + /// Given a slice of group elements `[a_0, a_1, a_2, ...]`, this returns + /// `[a_0, [c]a_1, [c^2]a_2, [c^3]a_3, [c^4]a_4, ...]`, + /// + fn distribute_powers(&self, a: &mut [F], c: F) { + parallelize(a, |a, index| { + let mut c_power = c.pow_vartime([index as u64]); + for a in a { + *a *= c_power; + c_power *= c; + } + }); + } + fn ifft(&self, a: &mut Vec, omega_inv: F, log_n: u32, divisor: F) { self.fft_inner(a, omega_inv, log_n, true); parallelize(a, |a, _| { @@ -777,16 +570,8 @@ impl> EvaluationDomain { } fn fft_inner(&self, a: &mut Vec, omega: F, log_n: u32, inverse: bool) { - if get_fft_mode() == 1 { - let fft_data = if a.len() == self.fft_data.n { - &self.fft_data - } else { - &self.extended_fft_data - }; - recursive_fft(fft_data, a, inverse); - } else { - best_fft(a, omega, log_n); - } + let fft_data = self.get_fft_data(a.len()); + best_fft(a, omega, log_n, fft_data, inverse) } /// Get the size of the domain @@ -903,6 +688,20 @@ impl> EvaluationDomain { omega: &self.omega, } } + + /// Get the private field `n` + pub fn get_n(&self) -> u64 { + self.n + } + + /// Get the private `fft_data` + pub fn get_fft_data(&self, l: usize) -> &FFTData { + if l == self.fft_data.get_n() { + &self.fft_data + } else { + &self.extended_fft_data + } + } } /// Represents the minimal parameters that determine an `EvaluationDomain`. 
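For reference, the `distribute_powers` helper added above implements the map `[a_0, a_1, a_2, ...] -> [a_0, c*a_1, c^2*a_2, ...]`; each parallel chunk seeds its own running power with `c.pow_vartime([index as u64])` so the threads never depend on one another. A minimal serial sketch of the same semantics; the function name `distribute_powers_serial` and the modulus are illustrative only, with u64 arithmetic modulo a prime standing in for a real field element:

fn distribute_powers_serial(a: &mut [u64], c: u64) {
    const P: u64 = 0xffff_ffff_0000_0001; // illustrative prime modulus, not from the patch
    let mul = |x: u64, y: u64| ((x as u128 * y as u128) % P as u128) as u64;
    let mut c_power = 1u64; // c^0
    for x in a.iter_mut() {
        *x = mul(*x, c_power); // a_i <- a_i * c^i
        c_power = mul(c_power, c); // advance to c^{i+1}
    }
}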
@@ -986,35 +785,175 @@ fn test_l_i() { } #[test] -fn test_fft() { +fn test_coeff_to_extended_part() { use halo2curves::pasta::pallas::Scalar; use rand_core::OsRng; - fn get_degree() -> usize { - var("DEGREE") - .unwrap_or_else(|_| "8".to_string()) - .parse() - .expect("Cannot parse DEGREE env var as usize") + let domain = EvaluationDomain::::new(1, 3); + let rng = OsRng; + let mut poly = domain.empty_coeff(); + assert_eq!(poly.len(), 8); + for value in poly.iter_mut() { + *value = Scalar::random(rng); } - let k = get_degree() as u32; - let domain = EvaluationDomain::::new(1, k); - let n = domain.n as usize; + let want = domain.coeff_to_extended(&poly); + let got = { + let parts = domain.coeff_to_extended_parts(&poly); + domain.lagrange_vec_to_extended(parts) + }; + assert_eq!(want.values, got.values); +} + +#[test] +fn bench_coeff_to_extended_parts() { + use halo2curves::pasta::pallas::Scalar; + use rand_core::OsRng; + use std::time::Instant; - let input = vec![Scalar::random(OsRng); n]; - /*let mut input = vec![Scalar::zero(); n]; - for i in 0..n { - input[i] = Scalar::random(OsRng); - }*/ + let k = 20; + let domain = EvaluationDomain::::new(3, k); + let rng = OsRng; + let mut poly1 = domain.empty_coeff(); + assert_eq!(poly1.len(), 1 << k); - let mut a = input.clone(); - best_fft(&mut a, domain.omega, k); + for value in poly1.iter_mut() { + *value = Scalar::random(rng); + } - let mut b = input.clone(); - recursive_fft(&domain.fft_data, &mut b, false); + let poly2 = poly1.clone(); - for i in 0..n { - //println!("{}: {} {}", i, a[i], b[i]); - assert_eq!(a[i], b[i]); + let coeff_to_extended_timer = Instant::now(); + let _ = domain.coeff_to_extended(&poly1); + println!( + "domain.coeff_to_extended time: {}s", + coeff_to_extended_timer.elapsed().as_secs_f64() + ); + + let coeff_to_extended_parts_timer = Instant::now(); + let _ = domain.coeff_to_extended_parts(&poly2); + println!( + "domain.coeff_to_extended_parts time: {}s", + coeff_to_extended_parts_timer.elapsed().as_secs_f64() + ); +} + +#[test] +fn test_lagrange_vecs_to_extended() { + use halo2curves::pasta::pallas::Scalar; + use rand_core::OsRng; + + let rng = OsRng; + let domain = EvaluationDomain::::new(8, 3); + let mut poly_vec = vec![]; + let mut poly_lagrange_vecs = vec![]; + let mut want = domain.empty_extended(); + let mut omega = domain.extended_omega; + for i in (0..(domain.extended_k - domain.k + 1)).rev() { + let mut poly = vec![Scalar::zero(); (1 << i) * domain.n as usize]; + for value in poly.iter_mut() { + *value = Scalar::random(rng); + } + // poly under coeff representation. + poly_vec.push(poly.clone()); + // poly under lagrange vector representation. + let mut poly2 = poly.clone(); + let data = domain.get_fft_data(poly2.len()); + best_fft(&mut poly2, omega, i + domain.k, data, false); + let transposed_poly: Vec> = (0..(1 << i)) + .map(|j| { + let mut p = domain.empty_lagrange(); + for k in 0..domain.n { + p[k as usize] = poly2[j + (k as usize) * (1 << i)]; + } + p + }) + .collect(); + poly_lagrange_vecs.push(transposed_poly); + // poly under extended representation. 
+ poly.resize(domain.extended_len(), Scalar::zero()); + let data = domain.get_fft_data(poly.len()); + best_fft( + &mut poly, + domain.extended_omega, + domain.extended_k, + data, + false, + ); + let poly = { + let mut p = domain.empty_extended(); + p.values = poly; + p + }; + want = want + &poly; + omega = omega * omega; } + + poly_lagrange_vecs.reverse(); + let got = domain.lagrange_vecs_to_extended(poly_lagrange_vecs); + assert_eq!(want.values, got.values); +} + +#[test] +fn bench_lagrange_vecs_to_extended() { + use halo2curves::pasta::pallas::Scalar; + use rand_core::OsRng; + use std::time::Instant; + + let rng = OsRng; + let domain = EvaluationDomain::::new(8, 10); + let mut poly_vec = vec![]; + let mut poly_lagrange_vecs = vec![]; + let mut poly_extended_vecs = vec![]; + let mut omega = domain.extended_omega; + + for i in (0..(domain.extended_k - domain.k + 1)).rev() { + let mut poly = vec![Scalar::zero(); (1 << i) * domain.n as usize]; + for value in poly.iter_mut() { + *value = Scalar::random(rng); + } + // poly under coeff representation. + poly_vec.push(poly.clone()); + // poly under lagrange vector representation. + let mut poly2 = poly.clone(); + let data = domain.get_fft_data(poly2.len()); + best_fft(&mut poly2, omega, i + domain.k, data, false); + let transposed_poly: Vec> = (0..(1 << i)) + .map(|j| { + let mut p = domain.empty_lagrange(); + for k in 0..domain.n { + p[k as usize] = poly2[j + (k as usize) * (1 << i)]; + } + p + }) + .collect(); + poly_lagrange_vecs.push(transposed_poly); + // poly under extended representation. + poly.resize(domain.extended_len(), Scalar::zero()); + let data = domain.get_fft_data(poly.len()); + best_fft( + &mut poly, + domain.extended_omega, + domain.extended_k, + data, + false, + ); + let poly = { + let mut p = domain.empty_extended(); + p.values = poly; + p + }; + poly_extended_vecs.push(poly); + omega = omega * omega; + } + + let want_timer = Instant::now(); + let _ = poly_extended_vecs + .iter() + .fold(domain.empty_extended(), |acc, p| acc + p); + println!("want time: {}s", want_timer.elapsed().as_secs_f64()); + poly_lagrange_vecs.reverse(); + let got_timer = Instant::now(); + let _ = domain.lagrange_vecs_to_extended(poly_lagrange_vecs); + println!("got time: {}s", got_timer.elapsed().as_secs_f64()); } From ad393e52d86a9d4a2b7c202d359914c119d7c36b Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 26 Oct 2023 17:46:52 -0700 Subject: [PATCH 2/6] fix: recursive FFT for lengths not `2^k, 2^extended_k` Also fixed some issue when multicore feature is not on. --- halo2_proofs/src/arithmetic.rs | 23 ---------- .../src/circuit/floor_planner/single_pass.rs | 1 + halo2_proofs/src/fft/recursive.rs | 2 +- halo2_proofs/src/plonk/keygen.rs | 6 +-- halo2_proofs/src/plonk/prover.rs | 21 +++++---- halo2_proofs/src/poly/domain.rs | 46 ++++++++++--------- 6 files changed, 41 insertions(+), 58 deletions(-) diff --git a/halo2_proofs/src/arithmetic.rs b/halo2_proofs/src/arithmetic.rs index 6c026bbcba..d0e19873c7 100644 --- a/halo2_proofs/src/arithmetic.rs +++ b/halo2_proofs/src/arithmetic.rs @@ -374,29 +374,6 @@ pub fn parallelize(v: &mu }); } -/// This simple utility function will parallelize an operation that is to be -/// performed over a mutable slice. 
-pub fn parallelize_count( - v: &mut [T], - num_threads: usize, - f: F, -) { - let n = v.len(); - let mut chunk = n / num_threads; - if chunk < num_threads { - chunk = n; - } - - multicore::scope(|scope| { - for (chunk_num, v) in v.chunks_mut(chunk).enumerate() { - let f = f.clone(); - scope.spawn(move |_| { - f(v, chunk_num); - }); - } - }); -} - pub fn log2_floor(num: usize) -> u32 { assert!(num > 0); diff --git a/halo2_proofs/src/circuit/floor_planner/single_pass.rs b/halo2_proofs/src/circuit/floor_planner/single_pass.rs index 824ad427f8..77a3960c47 100644 --- a/halo2_proofs/src/circuit/floor_planner/single_pass.rs +++ b/halo2_proofs/src/circuit/floor_planner/single_pass.rs @@ -402,6 +402,7 @@ mod tests { impl Circuit for MyCircuit { type Config = Column; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn params(&self) -> Self::Params {} fn without_witnesses(&self) -> Self { diff --git a/halo2_proofs/src/fft/recursive.rs b/halo2_proofs/src/fft/recursive.rs index 7973336c49..e905904463 100644 --- a/halo2_proofs/src/fft/recursive.rs +++ b/halo2_proofs/src/fft/recursive.rs @@ -95,7 +95,7 @@ impl FFTData { // Twiddles parallelize(twiddles, |twiddles, start| { let w_m = o; - let mut w = o.pow_vartime([start as u64, 0, 0, 0]); + let mut w = o.pow_vartime([start as u64]); for value in twiddles.iter_mut() { *value = w; w *= w_m; diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index a5b7e87e92..9b68a6f784 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -16,7 +16,7 @@ use super::{ use crate::{ arithmetic::{parallelize, CurveAffine}, circuit::Value, - multicore::{IntoParallelRefIterator, ParallelIterator}, + multicore::{IntoParallelIterator, ParallelIterator}, poly::{ batch_invert_assigned, commitment::{Blind, Params}, @@ -261,8 +261,8 @@ where .permutation .build_vk(params, &domain, &cs.permutation); - let fixed_commitments = fixed - .par_iter() + let fixed_commitments = (&fixed) + .into_par_iter() .map(|poly| params.commit_lagrange(poly, Blind::default()).to_affine()) .collect(); diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index 95bfa5ef07..1cea2761b4 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -8,7 +8,9 @@ use std::hash::Hash; use std::marker::PhantomData; use std::ops::RangeTo; -use crate::multicore::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +#[cfg(feature = "multicore")] +use crate::multicore::IndexedParallelIterator; +use crate::multicore::{IntoParallelIterator, ParallelIterator}; use std::{collections::HashMap, iter}; use super::{ @@ -296,12 +298,11 @@ where } } } else { - let instance_commitments_projective: Vec<_> = self - .instance_single - .instance_values - .par_iter() - .map(|poly| self.params.commit_lagrange(poly, Blind::default())) - .collect(); + let instance_commitments_projective: Vec<_> = + (&self.instance_single.instance_values) + .into_par_iter() + .map(|poly| self.params.commit_lagrange(poly, Blind::default())) + .collect(); let mut instance_commitments = vec![C::identity(); instance_commitments_projective.len()]; C::CurveExt::batch_normalize( @@ -338,9 +339,9 @@ where .iter() .map(|_| Blind(F::random(&mut self.rng))) .collect(); - let advice_commitments_projective: Vec<_> = advice_values - .par_iter() - .zip(blinds.par_iter()) + let advice_commitments_projective: Vec<_> = (&advice_values) + .into_par_iter() + .zip((&blinds).into_par_iter()) .map(|(poly, blind)| 
self.params.commit_lagrange(poly, *blind)) .collect(); let mut advice_commitments = vec![C::identity(); advice_commitments_projective.len()]; diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index 3e88e06c0a..749c2b8c43 100644 --- a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -11,7 +11,7 @@ use super::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation}; use group::ff::{BatchInvert, Field, WithSmallOrderMulGroup}; -use std::marker::PhantomData; +use std::{collections::HashMap, marker::PhantomData}; /// This structure contains precomputed constants and other details needed for /// performing operations on an evaluation domain of size $2^k$ and an extended @@ -34,8 +34,7 @@ pub struct EvaluationDomain { barycentric_weight: F, // Recursive stuff - fft_data: FFTData, - extended_fft_data: FFTData, + fft_data: HashMap>, } impl> EvaluationDomain { @@ -70,19 +69,22 @@ impl> EvaluationDomain { extended_omega = extended_omega.square(); } let extended_omega = extended_omega; - let mut extended_omega_inv = extended_omega; // Inversion computed later // Get omega, the 2^{k}'th root of unity (i.e. n'th root of unity) // The loop computes omega = extended_omega ^ {2 ^ (extended_k - k)} // = (omega^{2 ^ (S - extended_k)}) ^ {2 ^ (extended_k - k)} // = omega ^ {2 ^ (S - k)}. // Notice that omega ^ {2^k} = omega ^ {2^S} = 1. + let mut omegas = Vec::with_capacity((extended_k - k + 1) as usize); let mut omega = extended_omega; + omegas.push(omega); for _ in k..extended_k { omega = omega.square(); + omegas.push(omega); } let omega = omega; - let mut omega_inv = omega; // Inversion computed later + omegas.reverse(); + let mut omegas_inv = omegas.clone(); // Inversion computed later // We use zeta here because we know it generates a coset, and it's available // already. 
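For intuition, the omega chain computed above relies on the fact that squaring a primitive 2^j-th root of unity yields a primitive 2^{j-1}-th root of unity, so a single squaring pass produces a generator for every domain size between 2^k and 2^{extended_k}, which is exactly what the new per-size FFT data needs. A standalone sketch of that derivation; `omega_chain` is a hypothetical helper, with u64 arithmetic modulo an illustrative prime standing in for the field:

fn omega_chain(extended_omega: u64, k: u32, extended_k: u32) -> Vec<u64> {
    const P: u64 = 0xffff_ffff_0000_0001; // illustrative prime modulus, not from the patch
    let sq = |x: u64| ((x as u128 * x as u128) % P as u128) as u64;
    let mut omegas = Vec::with_capacity((extended_k - k + 1) as usize);
    let mut w = extended_omega; // generator of the order-2^{extended_k} subgroup
    omegas.push(w);
    for _ in k..extended_k {
        w = sq(w); // squaring halves the subgroup order
        omegas.push(w);
    }
    omegas.reverse(); // omegas[i] now generates the subgroup of order 2^{k + i}
    omegas
}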
@@ -129,10 +131,18 @@ impl> EvaluationDomain { .chain(Some(&mut ifft_divisor)) .chain(Some(&mut extended_ifft_divisor)) .chain(Some(&mut barycentric_weight)) - .chain(Some(&mut extended_omega_inv)) - .chain(Some(&mut omega_inv)) + .chain(&mut omegas_inv) .batch_invert(); + let omega_inv = omegas_inv[0]; + let extended_omega_inv = *omegas_inv.last().unwrap(); + let mut fft_data = HashMap::new(); + for (i, (omega, omega_inv)) in omegas.into_iter().zip(omegas_inv).enumerate() { + let intermediate_k = k as usize + i; + let len = 1usize << intermediate_k; + fft_data.insert(len, FFTData::::new(len, omega, omega_inv)); + } + EvaluationDomain { n, k, @@ -148,12 +158,7 @@ impl> EvaluationDomain { extended_ifft_divisor, t_evaluations, barycentric_weight, - fft_data: FFTData::::new(n as usize, omega, omega_inv), - extended_fft_data: FFTData::::new( - (1 << extended_k) as usize, - extended_omega, - extended_omega_inv, - ), + fft_data, } } @@ -560,7 +565,9 @@ impl> EvaluationDomain { } fn ifft(&self, a: &mut Vec, omega_inv: F, log_n: u32, divisor: F) { - self.fft_inner(a, omega_inv, log_n, true); + let fft_data = self.get_fft_data(a.len()); + crate::fft::parallel::fft(a, omega_inv, log_n, fft_data, true); + // self.fft_inner(a, omega_inv, log_n, true); parallelize(a, |a, _| { for a in a { // Finish iFFT @@ -696,11 +703,9 @@ impl> EvaluationDomain { /// Get the private `fft_data` pub fn get_fft_data(&self, l: usize) -> &FFTData { - if l == self.fft_data.get_n() { - &self.fft_data - } else { - &self.extended_fft_data - } + self.fft_data + .get(&l) + .expect("log_2(l) must be in k..=extended_k") } } @@ -844,7 +849,7 @@ fn test_lagrange_vecs_to_extended() { use rand_core::OsRng; let rng = OsRng; - let domain = EvaluationDomain::::new(8, 3); + let domain = EvaluationDomain::::new(8, 10); let mut poly_vec = vec![]; let mut poly_lagrange_vecs = vec![]; let mut want = domain.empty_extended(); @@ -888,7 +893,6 @@ fn test_lagrange_vecs_to_extended() { want = want + &poly; omega = omega * omega; } - poly_lagrange_vecs.reverse(); let got = domain.lagrange_vecs_to_extended(poly_lagrange_vecs); assert_eq!(want.values, got.values); From 22a0944fbce8f757e3342cb19df0be702d01f79a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 26 Oct 2023 18:08:36 -0700 Subject: [PATCH 3/6] fix: evaluation (cherry-pick https://github.com/scroll-tech/halo2-gpu/blob/a3019b047ad7d7119103f4d9df3fd33ce1429f95/halo2_proofs/src/plonk/evaluation.rs) --- .github/workflows/ci.yml | 4 +- halo2_proofs/src/plonk/evaluation.rs | 968 ++++++++------------------- halo2_proofs/src/poly/domain.rs | 26 + 3 files changed, 314 insertions(+), 684 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fecfc5872d..e3574cac8a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,9 +31,9 @@ jobs: with: override: false - name: Run examples - run: cargo run --example serialization + run: | + cargo run --example serialization cargo run --example shuffle - cargo run --example shuffle_api fmt: name: Rustfmt diff --git a/halo2_proofs/src/plonk/evaluation.rs b/halo2_proofs/src/plonk/evaluation.rs index 64de1bffa0..b33a6e3d8b 100644 --- a/halo2_proofs/src/plonk/evaluation.rs +++ b/halo2_proofs/src/plonk/evaluation.rs @@ -1,3 +1,4 @@ +#![allow(clippy::too_many_arguments)] use crate::multicore; use crate::plonk::{lookup, permutation, Any, ProvingKey}; use crate::poly::Basis; @@ -5,8 +6,7 @@ use crate::{ arithmetic::{parallelize, CurveAffine}, poly::{Coeff, 
ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation}, }; -use group::ff::{Field, PrimeField, WithSmallOrderMulGroup}; -use itertools::Itertools; +use ff::{Field, PrimeField, WithSmallOrderMulGroup}; use super::{ConstraintSystem, Expression}; @@ -37,7 +37,9 @@ pub enum ValueSource { /// theta Theta(), /// y - Y(usize), + Y(), + /// Previous value + PreviousValue(), } impl Default for ValueSource { @@ -48,41 +50,39 @@ impl Default for ValueSource { impl ValueSource { /// Get the value for this source - #[allow(clippy::too_many_arguments)] pub fn get( &self, rotations: &[usize], constants: &[F], intermediates: &[F], - fixed_values: &[Option>], - advice_values: &[Option>], - instance_values: &[Option>], + fixed_values: &[Polynomial], + advice_values: &[Polynomial], + instance_values: &[Polynomial], challenges: &[F], - y_powers: &[F], beta: &F, gamma: &F, theta: &F, + y: &F, + previous_value: &F, ) -> F { match self { ValueSource::Constant(idx) => constants[*idx], ValueSource::Intermediate(idx) => intermediates[*idx], ValueSource::Fixed(column_index, rotation) => { - assert!(fixed_values[*column_index].is_some()); - fixed_values[*column_index].as_ref().unwrap()[rotations[*rotation]] + fixed_values[*column_index][rotations[*rotation]] } ValueSource::Advice(column_index, rotation) => { - assert!(advice_values[*column_index].is_some()); - advice_values[*column_index].as_ref().unwrap()[rotations[*rotation]] + advice_values[*column_index][rotations[*rotation]] } ValueSource::Instance(column_index, rotation) => { - assert!(instance_values[*column_index].is_some()); - instance_values[*column_index].as_ref().unwrap()[rotations[*rotation]] + instance_values[*column_index][rotations[*rotation]] } ValueSource::Challenge(index) => challenges[*index], ValueSource::Beta() => *beta, ValueSource::Gamma() => *gamma, ValueSource::Theta() => *theta, - ValueSource::Y(idx) => y_powers[*idx], + ValueSource::Y() => *y, + ValueSource::PreviousValue() => *previous_value, } } } @@ -110,20 +110,20 @@ pub enum Calculation { impl Calculation { /// Get the resulting value of this calculation - #[allow(clippy::too_many_arguments)] pub fn evaluate( &self, rotations: &[usize], constants: &[F], intermediates: &[F], - fixed_values: &[Option>], - advice_values: &[Option>], - instance_values: &[Option>], + fixed_values: &[Polynomial], + advice_values: &[Polynomial], + instance_values: &[Polynomial], challenges: &[F], - y_powers: &[F], beta: &F, gamma: &F, theta: &F, + y: &F, + previous_value: &F, ) -> F { let get_value = |value: &ValueSource| { value.get( @@ -134,10 +134,11 @@ impl Calculation { advice_values, instance_values, challenges, - y_powers, beta, gamma, theta, + y, + previous_value, ) }; match self { @@ -160,41 +161,13 @@ impl Calculation { } } -#[derive(Clone, Default, Debug)] -struct ConstraintCluster { - /// Used fixed columns in each cluster - used_fixed_columns: Vec, - /// Used instance columns in each cluster - used_instance_columns: Vec, - /// Used advice columns in each cluster - used_advice_columns: Vec, - /// Custom gates evalution - evaluator: GraphEvaluator, - /// The first index of constraints are being evaluated at in each cluster - first_constraint_idx: usize, - /// The last index of constraints are being evaluated at in each cluster - last_constraint_idx: usize, - /// The last value source - last_value_source: Option, -} - /// Evaluator #[derive(Clone, Default, Debug)] pub struct Evaluator { - /// list of constraint clusters - custom_gate_clusters: Vec>, - /// Number of custom gate constraints - 
num_custom_gate_constraints: usize,
-    /// Lookups evalution, degree, used instance and advice columns
-    #[allow(clippy::type_complexity)]
-    lookups: Vec<(
-        GraphEvaluator<C>,
-        usize,
-        (Vec<usize>, Vec<usize>, Vec<usize>),
-    )>,
-
-    /// Powers of y
-    num_y_powers: usize,
+    /// Custom gates evaluation
+    pub custom_gates: GraphEvaluator<C>,
+    /// Lookups evaluation
+    pub lookups: Vec<GraphEvaluator<C>>,
 }
 
 /// GraphEvaluator
@@ -228,122 +201,46 @@ pub struct CalculationInfo {
     pub target: usize,
 }
 
-fn merge_unique(a: Vec<usize>, b: Vec<usize>) -> Vec<usize> {
-    let mut result = a;
-    result.extend(b);
-    result.into_iter().unique().collect()
-}
-
 impl<C: CurveAffine> Evaluator<C> {
     /// Creates a new evaluation structure
     pub fn new(cs: &ConstraintSystem<C::ScalarExt>) -> Self {
         let mut ev = Evaluator::default();
-        let mut constraint_idx = 0;
-
-        // Compute the max cluster index
-        let quotient_poly_degree = (cs.degree() - 1) as u64;
-        let mut max_cluster_idx = 0;
-        while (1 << max_cluster_idx) < quotient_poly_degree {
-            max_cluster_idx += 1;
-        }
-
-        ev.custom_gate_clusters
-            .resize(max_cluster_idx + 1, ConstraintCluster::default());
 
         // Custom gates
+        let mut parts = Vec::new();
         for gate in cs.gates.iter() {
-            for poly in gate.polynomials() {
-                constraint_idx += 1;
-                let cluster_idx = Self::compute_cluster_idx(poly.degree(), max_cluster_idx);
-                let custom_gate_cluster = &mut ev.custom_gate_clusters[cluster_idx];
-                custom_gate_cluster.used_fixed_columns = merge_unique(
-                    custom_gate_cluster.used_fixed_columns.clone(),
-                    poly.extract_fixed(),
-                );
-                custom_gate_cluster.used_instance_columns = merge_unique(
-                    custom_gate_cluster.used_instance_columns.clone(),
-                    poly.extract_instances(),
-                );
-                custom_gate_cluster.used_advice_columns = merge_unique(
-                    custom_gate_cluster.used_advice_columns.clone(),
-                    poly.extract_advices(),
-                );
-                let curr = custom_gate_cluster.evaluator.add_expression(poly);
-                if let Some(last) = custom_gate_cluster.last_value_source {
-                    custom_gate_cluster.last_value_source = Some(
-                        custom_gate_cluster
-                            .evaluator
-                            .add_calculation(Calculation::Horner(
-                                last,
-                                vec![curr],
-                                ValueSource::Y(
-                                    constraint_idx - custom_gate_cluster.last_constraint_idx,
-                                ),
-                            )),
-                    );
-                } else {
-                    assert_eq!(custom_gate_cluster.last_constraint_idx, 0);
-                    custom_gate_cluster.last_value_source = Some(curr);
-                    custom_gate_cluster.first_constraint_idx = constraint_idx;
-                }
-                custom_gate_cluster.last_constraint_idx = constraint_idx;
-            }
+            parts.extend(
+                gate.polynomials()
+                    .iter()
+                    .map(|poly| ev.custom_gates.add_expression(poly)),
+            );
         }
-
-        ev.num_custom_gate_constraints = constraint_idx;
+        ev.custom_gates.add_calculation(Calculation::Horner(
+            ValueSource::PreviousValue(),
+            parts,
+            ValueSource::Y(),
+        ));
 
         // Lookups
         for lookup in cs.lookups.iter() {
-            constraint_idx += 5;
             let mut graph = GraphEvaluator::default();
 
             let mut evaluate_lc = |expressions: &Vec<Expression<C::ScalarExt>>| {
-                let mut max_degree = 0;
-                let mut used_fixed_columns = vec![];
-                let mut used_instance_columns = vec![];
-                let mut used_advice_columns = vec![];
                 let parts = expressions
                     .iter()
-                    .map(|expr| {
-                        max_degree = max_degree.max(expr.degree());
-                        used_fixed_columns =
-                            merge_unique(used_fixed_columns.clone(), expr.extract_fixed());
-                        used_instance_columns =
-                            merge_unique(used_instance_columns.clone(), expr.extract_instances());
-                        used_advice_columns =
-                            merge_unique(used_advice_columns.clone(), expr.extract_advices());
-                        graph.add_expression(expr)
-                    })
+                    .map(|expr| graph.add_expression(expr))
                     .collect();
-                (
-                    graph.add_calculation(Calculation::Horner(
-                        ValueSource::Constant(0),
-                        parts,
-                        ValueSource::Theta(),
-                    )),
-                    max_degree,
-
used_fixed_columns, - used_instance_columns, - used_advice_columns, - ) + graph.add_calculation(Calculation::Horner( + ValueSource::Constant(0), + parts, + ValueSource::Theta(), + )) }; // Input coset - let ( - compressed_input_coset, - max_input_degree, - input_used_fixed, - input_used_instances, - input_used_advices, - ) = evaluate_lc(&lookup.input_expressions); + let compressed_input_coset = evaluate_lc(&lookup.input_expressions); // table coset - let ( - compressed_table_coset, - max_table_degree, - table_used_fixed, - table_used_instances, - table_used_advices, - ) = evaluate_lc(&lookup.table_expressions); + let compressed_table_coset = evaluate_lc(&lookup.table_expressions); // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) let right_gamma = graph.add_calculation(Calculation::Add( compressed_table_coset, @@ -354,26 +251,14 @@ impl Evaluator { ValueSource::Beta(), )); graph.add_calculation(Calculation::Mul(lc, right_gamma)); - ev.lookups.push(( - graph, - max_input_degree + max_table_degree, - ( - merge_unique(input_used_fixed, table_used_fixed), - merge_unique(input_used_instances, table_used_instances), - merge_unique(input_used_advices, table_used_advices), - ), - )); + + ev.lookups.push(graph); } - // Count the constraints in permutation - let num_sets = (cs.permutation.get_columns().len() + (cs.degree() - 3)) / (cs.degree() - 2); - constraint_idx += 1 + num_sets * 2; - ev.num_y_powers = constraint_idx + 10; ev } /// Evaluate h poly - #[allow(clippy::too_many_arguments)] pub(in crate::plonk) fn evaluate_h( &self, pk: &ProvingKey, @@ -385,7 +270,6 @@ impl Evaluator { gamma: C::ScalarExt, theta: C::ScalarExt, lookups: &[Vec>], - // shuffles: &[Vec>], permutations: &[permutation::prover::Committed], ) -> Polynomial { let domain = &pk.vk.domain; @@ -397,140 +281,98 @@ impl Evaluator { let one = C::ScalarExt::ONE; let p = &pk.vk.cs.permutation; let num_parts = domain.extended_len() >> domain.k(); - let num_clusters = (domain.extended_k() - domain.k() + 1) as usize; - - assert!(self.custom_gate_clusters.len() <= num_clusters); - - // Initialize the the powers of y and constraint counter - let mut y_powers = vec![C::ScalarExt::ONE; self.num_y_powers * instance_polys.len()]; - for i in 1..self.num_y_powers { - y_powers[i] = y_powers[i - 1] * y; - } - - let need_to_compute = |part_idx, cluster_idx| part_idx % (num_parts >> cluster_idx) == 0; - let compute_part_idx_in_cluster = - |part_idx, cluster_idx| part_idx >> (num_clusters - cluster_idx - 1); - - let mut value_part_clusters = Vec::new(); - value_part_clusters.resize(num_clusters, Vec::new()); - for (cluster_idx, cluster) in value_part_clusters - .iter_mut() - .enumerate() - .take(num_clusters) - { - cluster.resize(1 << cluster_idx, domain.empty_lagrange()); - } // Calculate the quotient polynomial for each part let mut current_extended_omega = one; - for part_idx in 0..num_parts { - let mut fixed: Vec>> = - vec![None; pk.fixed_polys.len()]; - let l0 = domain.coeff_to_extended_part(pk.l0.clone(), current_extended_omega); - let l_last = domain.coeff_to_extended_part(pk.l_last.clone(), current_extended_omega); - let l_active_row = - domain.coeff_to_extended_part(pk.l_active_row.clone(), current_extended_omega); - - let mut constraint_idx = 0; - let mut cluster_last_constraint_idx = vec![0; num_clusters]; + let value_parts: Vec> = (0..num_parts) + .map(|_| { + let fixed: Vec> = pk + .fixed_polys + .iter() + .map(|p| domain.coeff_to_extended_part(p.clone(), current_extended_omega)) + .collect(); + let fixed = &fixed[..]; + let l0 = 
domain.coeff_to_extended_part(pk.l0.clone(), current_extended_omega); + let l_last = + domain.coeff_to_extended_part(pk.l_last.clone(), current_extended_omega); + let l_active_row = + domain.coeff_to_extended_part(pk.l_active_row.clone(), current_extended_omega); - // Core expression evaluations - let num_threads = multicore::current_num_threads(); - for (((advice_polys, instance_polys), lookups), permutation) in advice_polys - .iter() - .zip(instance_polys.iter()) - .zip(lookups.iter()) - .zip(permutations.iter()) - { // Calculate the advice and instance cosets - let mut advice: Vec>> = - vec![None; advice_polys.len()]; - let mut instance: Vec>> = - vec![None; instance_polys.len()]; + let advice: Vec>> = advice_polys + .iter() + .map(|advice_polys| { + advice_polys + .iter() + .map(|poly| { + domain.coeff_to_extended_part(poly.clone(), current_extended_omega) + }) + .collect() + }) + .collect(); + let instance: Vec>> = instance_polys + .iter() + .map(|instance_polys| { + instance_polys + .iter() + .map(|poly| { + domain.coeff_to_extended_part(poly.clone(), current_extended_omega) + }) + .collect() + }) + .collect(); - // Custom gates - for (cluster_idx, custom_gates) in self.custom_gate_clusters.iter().enumerate() { - if !need_to_compute(part_idx, cluster_idx) - || custom_gates.last_value_source.is_none() - { - continue; - } - let values = &mut value_part_clusters[cluster_idx] - [compute_part_idx_in_cluster(part_idx, cluster_idx)]; - for fixed_idx in custom_gates.used_fixed_columns.iter() { - if fixed[*fixed_idx].is_none() { - fixed[*fixed_idx] = Some(domain.coeff_to_extended_part( - pk.fixed_polys[*fixed_idx].clone(), - current_extended_omega, - )); - } - } - for instance_idx in custom_gates.used_instance_columns.iter() { - if instance[*instance_idx].is_none() { - instance[*instance_idx] = Some(domain.coeff_to_extended_part( - instance_polys[*instance_idx].clone(), - current_extended_omega, - )); - } - } - for advice_idx in custom_gates.used_advice_columns.iter() { - if advice[*advice_idx].is_none() { - advice[*advice_idx] = Some(domain.coeff_to_extended_part( - advice_polys[*advice_idx].clone(), - current_extended_omega, - )); - } - } - let fixed_slice = &fixed[..]; - let advice_slice = &advice[..]; - let instance_slice = &instance[..]; - let y_power_slice = &y_powers[..]; - let y_power = y_powers[constraint_idx + custom_gates.first_constraint_idx - - cluster_last_constraint_idx[cluster_idx]]; + let mut values = domain.empty_lagrange(); + + // Core expression evaluations + let num_threads = multicore::current_num_threads(); + for (((advice, instance), lookups), permutation) in advice + .iter() + .zip(instance.iter()) + .zip(lookups.iter()) + .zip(permutations.iter()) + { + // Custom gates multicore::scope(|scope| { let chunk_size = (size + num_threads - 1) / num_threads; for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() { let start = thread_idx * chunk_size; scope.spawn(move |_| { - let mut eval_data = custom_gates.evaluator.instance(); + let mut eval_data = self.custom_gates.instance(); for (i, value) in values.iter_mut().enumerate() { let idx = start + i; - *value = *value * y_power - + custom_gates.evaluator.evaluate( - &mut eval_data, - fixed_slice, - advice_slice, - instance_slice, - challenges, - y_power_slice, - &beta, - &gamma, - &theta, - idx, - rot_scale, - isize, - ); + *value = self.custom_gates.evaluate( + &mut eval_data, + fixed, + advice, + instance, + challenges, + &beta, + &gamma, + &theta, + &y, + value, + idx, + rot_scale, + isize, + ); } }); } }); 
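The per-row update in the loop above is Horner's rule in y: the single `Calculation::Horner(PreviousValue, parts, Y)` appended by `Evaluator::new` folds each gate evaluation e_i into the running cell value as v <- v*y + e_i, which is why `evaluate` now threads `previous_value` through and takes one `y` instead of a precomputed table of y powers. A scalar sketch of that fold; `horner_in_y` and the modulus are illustrative, with u64 modular arithmetic standing in for the field:

fn horner_in_y(previous: u64, gate_evals: &[u64], y: u64) -> u64 {
    const P: u64 = 0xffff_ffff_0000_0001; // illustrative prime modulus, not from the patch
    let mul = |a: u64, b: u64| ((a as u128 * b as u128) % P as u128) as u64;
    let add = |a: u64, b: u64| ((a as u128 + b as u128) % P as u128) as u64;
    // v <- v * y + e_i for every gate evaluation e_i at this row
    gate_evals.iter().fold(previous, |v, &e| add(mul(v, y), e))
}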
- // Update the constraint index - cluster_last_constraint_idx[cluster_idx] = - constraint_idx + custom_gates.last_constraint_idx; - } - constraint_idx += self.num_custom_gate_constraints; - - // Permutations - let sets = &permutation.sets; - if !sets.is_empty() { - let blinding_factors = pk.vk.cs.blinding_factors(); - let last_rotation = Rotation(-((blinding_factors + 1) as i32)); - let chunk_len = pk.vk.cs.degree() - 2; - let delta_start = beta * &C::Scalar::ZETA; - - let permutation_product_cosets: Vec> = - sets.iter() + // Permutations + let sets = &permutation.sets; + if !sets.is_empty() { + let blinding_factors = pk.vk.cs.blinding_factors(); + let last_rotation = Rotation(-((blinding_factors + 1) as i32)); + let chunk_len = pk.vk.cs.degree() - 2; + let delta_start = beta * &C::Scalar::ZETA; + + let permutation_product_cosets: Vec< + Polynomial, + > = sets + .iter() .map(|set| { domain.coeff_to_extended_part( set.permutation_product_poly.clone(), @@ -538,122 +380,6 @@ impl Evaluator { ) }) .collect(); - - let first_set_permutation_product_coset = - permutation_product_cosets.first().unwrap(); - let last_set_permutation_product_coset = - permutation_product_cosets.last().unwrap(); - - // Permutation constraints - constraint_idx += 1; - if need_to_compute(part_idx, 1) { - let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; - parallelize( - &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], - |values, start| { - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - // Enforce only for the first set. - // l_0(X) * (1 - z_0(X)) = 0, degree = 2 - *value = *value * y_power - + ((one - first_set_permutation_product_coset[idx]) - * l0[idx]); - } - }, - ); - cluster_last_constraint_idx[1] = constraint_idx; - } - - constraint_idx += 1; - if need_to_compute(part_idx, 2) { - let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; - parallelize( - &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], - |values, start| { - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - // Enforce only for the last set. - // l_last(X) * (z_l(X)^2 - z_l(X)) = 0, degree = 3 - *value = *value * y_power - + ((last_set_permutation_product_coset[idx] - * last_set_permutation_product_coset[idx] - - last_set_permutation_product_coset[idx]) - * l_last[idx]); - } - }, - ); - cluster_last_constraint_idx[2] = constraint_idx; - } - - constraint_idx += sets.len() - 1; - if need_to_compute(part_idx, 1) { - let y_skip = y_powers - [constraint_idx + 1 - sets.len() - cluster_last_constraint_idx[1]]; - parallelize( - &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], - |values, start| { - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - // Except for the first set, enforce. 
- // l_0(X) * (z_i(X) - z_{i-1}(\omega^(last) X)) = 0, degree = 2 - let r_last = - get_rotation_idx(idx, last_rotation.0, rot_scale, isize); - - *value *= y_skip; - - for (set_idx, permutation_product_coset) in - permutation_product_cosets.iter().enumerate() - { - if set_idx != 0 { - *value = *value * y - + ((permutation_product_coset[idx] - - permutation_product_cosets[set_idx - 1] - [r_last]) - * l0[idx]); - } - } - } - }, - ); - cluster_last_constraint_idx[1] = constraint_idx; - } - - constraint_idx += sets.len(); - let running_prod_cluster = - Self::compute_cluster_idx(2 + chunk_len, num_clusters - 1); - if need_to_compute(part_idx, running_prod_cluster) { - for column in p.columns.iter() { - match column.column_type() { - Any::Advice(_) => { - let advice = &mut advice[column.index()]; - if (*advice).is_none() { - *advice = Some(domain.coeff_to_extended_part( - advice_polys[column.index()].clone(), - current_extended_omega, - )); - } - } - Any::Instance => { - let instance = &mut instance[column.index()]; - if instance.is_none() { - *instance = Some(domain.coeff_to_extended_part( - instance_polys[column.index()].clone(), - current_extended_omega, - )); - } - } - Any::Fixed => { - let fixed = &mut fixed[column.index()]; - if fixed.is_none() { - *fixed = Some(domain.coeff_to_extended_part( - pk.fixed_polys[column.index()].clone(), - current_extended_omega, - )); - } - } - } - } - let permutation_cosets: Vec> = pk .permutation .polys @@ -663,301 +389,178 @@ impl Evaluator { }) .collect(); - let y_skip = y_powers[constraint_idx - - sets.len() - - cluster_last_constraint_idx[running_prod_cluster]]; - - parallelize( - &mut value_part_clusters[running_prod_cluster] - [compute_part_idx_in_cluster(part_idx, running_prod_cluster)], - |values, start| { - let mut beta_term = - current_extended_omega * omega.pow_vartime([start as u64]); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - let r_next = get_rotation_idx(idx, 1, rot_scale, isize); - - *value *= y_skip; - - // And for all the sets we enforce: - // (1 - (l_last(X) + l_blind(X))) * ( - // z_i(\omega X) \prod_j (p(X) + \beta s_j(X) + \gamma) - // - z_i(X) \prod_j (p(X) + \delta^j \beta X + \gamma) - // ), degree = 2 + chunk_len - let mut current_delta = delta_start * beta_term; - for ( - (columns, permutation_product_coset), - permutation_coset_chunk, - ) in p - .columns - .chunks(chunk_len) - .zip(permutation_product_cosets.iter()) - .zip(permutation_cosets.chunks(chunk_len)) - { - let mut left = permutation_product_coset[r_next]; - for (values, permutation) in columns - .iter() - .map(|&column| match column.column_type() { - Any::Advice(_) => { - advice[column.index()].as_ref().unwrap() - } - Any::Fixed => { - fixed[column.index()].as_ref().unwrap() - } - Any::Instance => { - instance[column.index()].as_ref().unwrap() - } - }) - .zip(permutation_coset_chunk.iter()) - { - left *= values[idx] + beta * permutation[idx] + gamma; - } - - let mut right = permutation_product_coset[idx]; - for values in columns.iter().map(|&column| { - match column.column_type() { - Any::Advice(_) => { - advice[column.index()].as_ref().unwrap() - } - Any::Fixed => { - fixed[column.index()].as_ref().unwrap() - } - Any::Instance => { - instance[column.index()].as_ref().unwrap() - } - } - }) { - right *= values[idx] + current_delta + gamma; - current_delta *= &C::Scalar::DELTA; - } - - *value = *value * y + ((left - right) * l_active_row[idx]); + let first_set_permutation_product_coset = + permutation_product_cosets.first().unwrap(); + let 
last_set_permutation_product_coset = + permutation_product_cosets.last().unwrap(); + + // Permutation constraints + parallelize(&mut values, |values, start| { + let mut beta_term = + current_extended_omega * omega.pow_vartime([start as u64]); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let r_next = get_rotation_idx(idx, 1, rot_scale, isize); + let r_last = + get_rotation_idx(idx, last_rotation.0, rot_scale, isize); + + // Enforce only for the first set. + // l_0(X) * (1 - z_0(X)) = 0 + *value = *value * y + + ((one - first_set_permutation_product_coset[idx]) * l0[idx]); + // Enforce only for the last set. + // l_last(X) * (z_l(X)^2 - z_l(X)) = 0 + *value = *value * y + + ((last_set_permutation_product_coset[idx] + * last_set_permutation_product_coset[idx] + - last_set_permutation_product_coset[idx]) + * l_last[idx]); + // Except for the first set, enforce. + // l_0(X) * (z_i(X) - z_{i-1}(\omega^(last) X)) = 0 + for (set_idx, permutation_product_coset) in + permutation_product_cosets.iter().enumerate() + { + if set_idx != 0 { + *value = *value * y + + ((permutation_product_coset[idx] + - permutation_product_cosets[set_idx - 1][r_last]) + * l0[idx]); } - beta_term *= ω } - }, - ); - cluster_last_constraint_idx[running_prod_cluster] = constraint_idx; - } - } - - // Lookups - for (n, lookup) in lookups.iter().enumerate() { - let (lookup_evaluator, max_degree, used_columns) = &self.lookups[n]; - let running_prod_cluster = - Self::compute_cluster_idx(max_degree + 2, num_clusters - 1); - if !need_to_compute(part_idx, 1) - && !need_to_compute(part_idx, 2) - && !need_to_compute(part_idx, running_prod_cluster) - { - constraint_idx += 5; - continue; - } - - // Polynomials required for this lookup. - // Calculated here so these only have to be kept in memory for the short time - // they are actually needed. 
- let product_coset = pk.vk.domain.coeff_to_extended_part( - lookup.product_poly.clone(), - current_extended_omega, - ); - let permuted_input_coset = pk.vk.domain.coeff_to_extended_part( - lookup.permuted_input_poly.clone(), - current_extended_omega, - ); - let permuted_table_coset = pk.vk.domain.coeff_to_extended_part( - lookup.permuted_table_poly.clone(), - current_extended_omega, - ); + // And for all the sets we enforce: + // (1 - (l_last(X) + l_blind(X))) * ( + // z_i(\omega X) \prod_j (p(X) + \beta s_j(X) + \gamma) + // - z_i(X) \prod_j (p(X) + \delta^j \beta X + \gamma) + // ) + let mut current_delta = delta_start * beta_term; + for ( + (columns, permutation_product_coset), + permutation_coset_chunk, + ) in p + .columns + .chunks(chunk_len) + .zip(permutation_product_cosets.iter()) + .zip(permutation_cosets.chunks(chunk_len)) + { + let mut left = permutation_product_coset[r_next]; + for (values, permutation) in columns + .iter() + .map(|&column| match column.column_type() { + Any::Advice(_) => &advice[column.index()], + Any::Fixed => &fixed[column.index()], + Any::Instance => &instance[column.index()], + }) + .zip(permutation_coset_chunk.iter()) + { + left *= values[idx] + beta * permutation[idx] + gamma; + } - // Lookup constraints - constraint_idx += 1; - if need_to_compute(part_idx, 1) { - let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; + let mut right = permutation_product_coset[idx]; + for values in + columns.iter().map(|&column| match column.column_type() { + Any::Advice(_) => &advice[column.index()], + Any::Fixed => &fixed[column.index()], + Any::Instance => &instance[column.index()], + }) + { + right *= values[idx] + current_delta + gamma; + current_delta *= &C::Scalar::DELTA; + } - parallelize( - &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], - |values, start| { - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - // l_0(X) * (1 - z(X)) = 0, degree = 2 - *value = - *value * y_power + ((one - product_coset[idx]) * l0[idx]); + *value = *value * y + ((left - right) * l_active_row[idx]); } - }, - ); - cluster_last_constraint_idx[1] = constraint_idx; + beta_term *= ω + } + }); } - constraint_idx += 1; - if need_to_compute(part_idx, 2) { - let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; - parallelize( - &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], - |values, start| { - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - // l_last(X) * (z(X)^2 - z(X)) = 0, degree = 3 - *value = *value * y_power - + ((product_coset[idx] * product_coset[idx] - - product_coset[idx]) - * l_last[idx]); - } - }, + // Lookups + for (n, lookup) in lookups.iter().enumerate() { + // Polynomials required for this lookup. + // Calculated here so these only have to be kept in memory for the short time + // they are actually needed. 
+ let product_coset = pk.vk.domain.coeff_to_extended_part( + lookup.product_poly.clone(), + current_extended_omega, ); - cluster_last_constraint_idx[2] = constraint_idx; - } - constraint_idx += 1; - if need_to_compute(part_idx, running_prod_cluster) { - for fixed_column in used_columns.0.iter() { - let fixed = &mut fixed[*fixed_column]; - if fixed.is_none() { - *fixed = Some(domain.coeff_to_extended_part( - pk.fixed_polys[*fixed_column].clone(), - current_extended_omega, - )); - } - } - for instance_column in used_columns.1.iter() { - let instance = &mut instance[*instance_column]; - if instance.is_none() { - *instance = Some(domain.coeff_to_extended_part( - instance_polys[*instance_column].clone(), - current_extended_omega, - )); - } - } - - for advice_column in used_columns.2.iter() { - let advice = &mut advice[*advice_column]; - if (*advice).is_none() { - *advice = Some(domain.coeff_to_extended_part( - advice_polys[*advice_column].clone(), - current_extended_omega, - )); - } - } - - let y_power = y_powers - [constraint_idx - cluster_last_constraint_idx[running_prod_cluster]]; - let fixed_slice = &fixed[..]; - let advice_slice = &advice[..]; - let instance_slice = &instance[..]; - let y_power_slice = &y_powers[..]; - parallelize( - &mut value_part_clusters[running_prod_cluster] - [compute_part_idx_in_cluster(part_idx, running_prod_cluster)], - |values, start| { - let mut eval_data = lookup_evaluator.instance(); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - let table_value = lookup_evaluator.evaluate( - &mut eval_data, - fixed_slice, - advice_slice, - instance_slice, - challenges, - y_power_slice, - &beta, - &gamma, - &theta, - idx, - rot_scale, - isize, - ); - - let r_next = get_rotation_idx(idx, 1, rot_scale, isize); - - // (1 - (l_last(X) + l_blind(X))) * ( - // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) - // - z(X) (\theta^{m-1} a_0(X) + ... + a_{m-1}(X) + \beta) - // (\theta^{m-1} s_0(X) + ... + s_{m-1}(X) + \gamma) - // ) = 0, degree = 2 + max(deg(a)) + max(deg(s)) - *value = *value * y_power - + ((product_coset[r_next] - * (permuted_input_coset[idx] + beta) - * (permuted_table_coset[idx] + gamma) - - product_coset[idx] * table_value) - * l_active_row[idx]); - } - }, + let permuted_input_coset = pk.vk.domain.coeff_to_extended_part( + lookup.permuted_input_poly.clone(), + current_extended_omega, ); - cluster_last_constraint_idx[running_prod_cluster] = constraint_idx; - } - - constraint_idx += 1; - if need_to_compute(part_idx, 1) { - let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; - parallelize( - &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], - |values, start| { - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - let a_minus_s = - permuted_input_coset[idx] - permuted_table_coset[idx]; - // Check that the first values in the permuted input expression and permuted - // fixed expression are the same. 
- // l_0(X) * (a'(X) - s'(X)) = 0, degree = 2 - *value = *value * y_power + (a_minus_s * l0[idx]); - } - }, + let permuted_table_coset = pk.vk.domain.coeff_to_extended_part( + lookup.permuted_table_poly.clone(), + current_extended_omega, ); - cluster_last_constraint_idx[1] = constraint_idx; - } - - constraint_idx += 1; - if need_to_compute(part_idx, 2) { - let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; - parallelize( - &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], - |values, start| { - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - let r_prev = get_rotation_idx(idx, -1, rot_scale, isize); - // Check that each value in the permuted lookup input expression is either - // equal to the value above it, or the value at the same index in the - // permuted table expression. - // (1 - (l_last + l_blind)) * (a′(X) − s′(X))⋅(a′(X) − a′(\omega^{-1} X)) = 0, degree = 3 - let a_minus_s = - permuted_input_coset[idx] - permuted_table_coset[idx]; - *value = *value * y_power - + (a_minus_s - * (permuted_input_coset[idx] - - permuted_input_coset[r_prev]) - * l_active_row[idx]); - } - }, - ); - cluster_last_constraint_idx[2] = constraint_idx; - } - } - } - // Align the constraints by different powers of y. - for (i, cluster) in value_part_clusters.iter_mut().enumerate() { - if need_to_compute(part_idx, i) && cluster_last_constraint_idx[i] > 0 { - let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[i]]; - parallelize( - &mut cluster[compute_part_idx_in_cluster(part_idx, i)], - |values, _| { - for value in values.iter_mut() { - *value *= y_power; + // Lookup constraints + parallelize(&mut values, |values, start| { + let lookup_evaluator = &self.lookups[n]; + let mut eval_data = lookup_evaluator.instance(); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + + let table_value = lookup_evaluator.evaluate( + &mut eval_data, + fixed, + advice, + instance, + challenges, + &beta, + &gamma, + &theta, + &y, + &C::ScalarExt::ZERO, + idx, + rot_scale, + isize, + ); + + let r_next = get_rotation_idx(idx, 1, rot_scale, isize); + let r_prev = get_rotation_idx(idx, -1, rot_scale, isize); + + let a_minus_s = + permuted_input_coset[idx] - permuted_table_coset[idx]; + // l_0(X) * (1 - z(X)) = 0 + *value = *value * y + ((one - product_coset[idx]) * l0[idx]); + // l_last(X) * (z(X)^2 - z(X)) = 0 + *value = *value * y + + ((product_coset[idx] * product_coset[idx] + - product_coset[idx]) + * l_last[idx]); + // (1 - (l_last(X) + l_blind(X))) * ( + // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) + // - z(X) (\theta^{m-1} a_0(X) + ... + a_{m-1}(X) + \beta) + // (\theta^{m-1} s_0(X) + ... + s_{m-1}(X) + \gamma) + // ) = 0 + *value = *value * y + + ((product_coset[r_next] + * (permuted_input_coset[idx] + beta) + * (permuted_table_coset[idx] + gamma) + - product_coset[idx] * table_value) + * l_active_row[idx]); + // Check that the first values in the permuted input expression and permuted + // fixed expression are the same. + // l_0(X) * (a'(X) - s'(X)) = 0 + *value = *value * y + (a_minus_s * l0[idx]); + // Check that each value in the permuted lookup input expression is either + // equal to the value above it, or the value at the same index in the + // permuted table expression. 
+ // (1 - (l_last + l_blind)) * (a′(X) − s′(X))⋅(a′(X) − a′(\omega^{-1} X)) = 0 + *value = *value * y + + (a_minus_s + * (permuted_input_coset[idx] + - permuted_input_coset[r_prev]) + * l_active_row[idx]); } - }, - ); + }); + } } - } - current_extended_omega *= extended_omega; - } - domain.lagrange_vecs_to_extended(value_part_clusters) - } + current_extended_omega *= extended_omega; + values + }) + .collect(); - fn compute_cluster_idx(degree: usize, max_cluster_idx: usize) -> usize { - let mut idx = (31 - (degree as u32).leading_zeros()) as usize; - if 1 << idx < degree { - idx += 1; - } - std::cmp::min(max_cluster_idx, idx) + domain.extended_from_lagrange_vec(value_parts) } } @@ -1136,18 +739,18 @@ impl GraphEvaluator { } } - #[allow(clippy::too_many_arguments)] pub fn evaluate( &self, data: &mut EvaluationData, - fixed: &[Option>], - advice: &[Option>], - instance: &[Option>], + fixed: &[Polynomial], + advice: &[Polynomial], + instance: &[Polynomial], challenges: &[C::ScalarExt], - y_powers: &[C::ScalarExt], beta: &C::ScalarExt, gamma: &C::ScalarExt, theta: &C::ScalarExt, + y: &C::ScalarExt, + previous_value: &C::ScalarExt, idx: usize, rot_scale: i32, isize: i32, @@ -1167,10 +770,11 @@ impl GraphEvaluator { advice, instance, challenges, - y_powers, beta, gamma, theta, + y, + previous_value, ); } diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index 749c2b8c43..f796771781 100644 --- a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -224,6 +224,32 @@ impl> EvaluationDomain { } } + /// Obtains a polynomial in ExtendedLagrange form when given a vector of + /// Lagrange polynomials with total size `extended_n`; panics if the + /// provided vector is the wrong length. + pub fn extended_from_lagrange_vec( + &self, + values: Vec>, + ) -> Polynomial { + assert_eq!(values.len(), self.extended_len() >> self.k); + assert_eq!(values[0].len(), self.n as usize); + + // transpose the values in parallel + let mut transposed = vec![vec![F::ZERO; values.len()]; self.n as usize]; + values.into_iter().enumerate().for_each(|(i, p)| { + parallelize(&mut transposed, |transposed, start| { + for (transposed, p) in transposed.iter_mut().zip(p.values[start..].iter()) { + transposed[i] = *p; + } + }); + }); + + Polynomial { + values: transposed.into_iter().flatten().collect(), + _marker: PhantomData, + } + } + /// Returns an empty (zero) polynomial in the coefficient basis pub fn empty_coeff(&self) -> Polynomial { Polynomial { From ac53f70fb19c9363ee009fc5073a837b18c77476 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 26 Oct 2023 18:54:24 -0700 Subject: [PATCH 4/6] fix: `g_to_lagrange` uses inverse fft --- halo2_proofs/src/arithmetic.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2_proofs/src/arithmetic.rs b/halo2_proofs/src/arithmetic.rs index d0e19873c7..29b5433541 100644 --- a/halo2_proofs/src/arithmetic.rs +++ b/halo2_proofs/src/arithmetic.rs @@ -238,7 +238,7 @@ pub fn g_to_lagrange(g_projective: Vec, k: u32) - let n = g_lagrange_projective.len(); let fft_data = FFTData::new(n, omega, omega_inv); - best_fft(&mut g_lagrange_projective, omega_inv, k, &fft_data, false); + best_fft(&mut g_lagrange_projective, omega_inv, k, &fft_data, true); parallelize(&mut g_lagrange_projective, |g, _| { for g in g.iter_mut() { *g *= n_inv; From 076711d6f1f55e3366a6d835a1bd6a3f5d93c097 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 13 
Nov 2023 15:56:05 -0800 Subject: [PATCH 5/6] Bump version since this is a breaking change to pk --- halo2_proofs/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index 163d93e98b..ff9d358fb0 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2_proofs" -version = "0.3.1" +version = "0.4.0" authors = [ "Sean Bowe ", "Ying Tong Lai ", From fa995aaa9f6a6d65b33b82883d1438a8bb330fb8 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:14:46 -0800 Subject: [PATCH 6/6] chore: update halo2curves version --- halo2_proofs/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index fcc77b9d48..449dfe6465 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -55,13 +55,13 @@ name = "fft" harness = false [dependencies] -itertools = "0.10" +itertools = "0.11" backtrace = { version = "0.3", optional = true } crossbeam = "0.8" ff = "0.13" group = "0.13" pairing = "0.23" -halo2curves = { package = "halo2curves-axiom", version = "0.4.1", default-features = false, features = ["bits", "bn256-table", "derive_serde"] } +halo2curves = { package = "halo2curves-axiom", version = "0.4.2", default-features = false, features = ["bits", "bn256-table", "derive_serde"] } rand = "0.8" rand_core = { version = "0.6", default-features = false} tracing = "0.1"
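A closing note on the `g_to_lagrange` fix in patch 4: an inverse FFT is the forward transform evaluated at omega^{-1} followed by scaling every entry by n^{-1}, so the `omega_inv` argument and the `inverse: true` flag have to agree, which is what the one-line change restores. A naive O(n^2) round-trip sketch over a toy prime field; all constants and the `dft` helper are illustrative (the real code path uses `best_fft`):

fn dft(a: &[u64], w: u64, p: u64) -> Vec<u64> {
    let mul = |x: u64, y: u64| (x * y) % p; // safe for a tiny modulus like 17
    (0..a.len())
        .map(|i| {
            let w_i = (0..i).fold(1, |x, _| mul(x, w)); // w^i
            let mut w_ij = 1; // running w^{i*j}
            let mut acc = 0;
            for &a_j in a {
                acc = (acc + mul(a_j, w_ij)) % p;
                w_ij = mul(w_ij, w_i);
            }
            acc // sum_j a_j * w^{i*j}
        })
        .collect()
}

fn main() {
    // w = 4 has multiplicative order 4 modulo 17; both w^{-1} and n^{-1} = 4^{-1} are 13.
    let (p, w, w_inv, n_inv) = (17u64, 4u64, 13u64, 13u64);
    let a = vec![3u64, 1, 4, 1];
    let freq = dft(&a, w, p); // forward transform
    let back: Vec<u64> = dft(&freq, w_inv, p) // inverse: same transform at w^{-1} ...
        .into_iter()
        .map(|x| x * n_inv % p) // ... plus scaling by n^{-1}
        .collect();
    assert_eq!(a, back);
}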