feat: add new parallel implementation for permute_expression_pair #189

Open · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions halo2_proofs/Cargo.toml
@@ -43,6 +43,10 @@ harness = false
name = "dev_lookup"
harness = false

[[bench]]
name = "bench_lookup"
harness = false

[[bench]]
name = "fft"
harness = false
137 changes: 137 additions & 0 deletions halo2_proofs/benches/bench_lookup.rs
@@ -0,0 +1,137 @@
#[macro_use]
extern crate criterion;

use ff::{Field, PrimeField};
use halo2_proofs::circuit::{Layouter, SimpleFloorPlanner, Value};
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
use halo2_proofs::poly::ipa::multiopen::ProverIPA;
use halo2_proofs::poly::Rotation;
use halo2_proofs::transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer};
use halo2curves::pasta::{pallas, EqAffine};
use rand_core::OsRng;

use std::marker::PhantomData;

use criterion::{BenchmarkId, Criterion};

fn criterion_benchmark(c: &mut Criterion) {
    #[derive(Clone, Default)]
    struct MyCircuit<F: Field> {
        k: usize,
        _marker: PhantomData<F>,
    }

    #[derive(Clone)]
    struct MyConfig {
        selector: Selector,
        table: TableColumn,
        advice: Column<Advice>,
    }

    impl<F: PrimeField> Circuit<F> for MyCircuit<F> {
        type Config = MyConfig;
        type FloorPlanner = SimpleFloorPlanner;
        #[cfg(feature = "circuit-params")]
        type Params = ();

        fn without_witnesses(&self) -> Self {
            Self::default()
        }

        fn configure(meta: &mut ConstraintSystem<F>) -> MyConfig {
            let config = MyConfig {
                selector: meta.complex_selector(),
                table: meta.lookup_table_column(),
                advice: meta.advice_column(),
            };

            meta.lookup("lookup", |meta| {
                let selector = meta.query_selector(config.selector);
                let not_selector = Expression::Constant(F::ONE) - selector.clone();
                let advice = meta.query_advice(config.advice, Rotation::cur());
                vec![(selector * advice + not_selector, config.table)]
            });

            config
        }

        fn synthesize(
            &self,
            config: MyConfig,
            mut layouter: impl Layouter<F>,
        ) -> Result<(), Error> {
            layouter.assign_table(
                || "lookup table",
                |mut table| {
                    for row in 0u64..(1 << (self.k - 1)) {
                        table.assign_cell(
                            || format!("row {}", row),
                            config.table,
                            row as usize,
                            || Value::known(F::from(row)),
                        )?;
                    }

                    Ok(())
                },
            )?;

            layouter.assign_region(
                || "assign values",
                |mut region| {
                    for offset in 0u64..(1 << self.k) - 20 {
                        config.selector.enable(&mut region, offset as usize)?;
                        region.assign_advice(
                            || format!("offset {}", offset),
                            config.advice,
                            offset as usize,
                            || Value::known(F::from(offset >> 1)),
                        )?;
                    }

                    Ok(())
                },
            )
        }
    }

    let k_range = 14..=18;

    let mut prover_group = c.benchmark_group("bench-lookup");
    prover_group.sample_size(10);
    for k in k_range {
        let circuit = MyCircuit::<pallas::Base> {
            k: k as usize,
            _marker: PhantomData,
        };
        let params = ParamsIPA::<EqAffine>::new(k);
        let vk = keygen_vk(&params, &circuit).unwrap();
        let pk = keygen_pk(&params, vk, &circuit).unwrap();
        prover_group.bench_with_input(
            BenchmarkId::from_parameter(k),
            &(params, pk),
            |b, (params, pk)| {
                b.iter(|| {
                    let mut transcript = Blake2bWrite::<_, _, Challenge255<EqAffine>>::init(vec![]);
                    let rng = OsRng;
                    create_proof::<IPACommitmentScheme<EqAffine>, ProverIPA<EqAffine>, _, _, _, _>(
                        params,
                        pk,
                        &[circuit.clone()],
                        &[&[]],
                        rng,
                        &mut transcript,
                    )
                    .unwrap();
                    transcript.finalize();
                });
            },
        );
    }
    prover_group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
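
With the [[bench]] entry added to Cargo.toml above, this benchmark can be run on its own with cargo bench --bench bench_lookup; expect long runs toward the top of the k = 14..=18 sweep.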
129 changes: 128 additions & 1 deletion halo2_proofs/src/plonk/lookup/prover.rs
@@ -19,6 +19,7 @@ use group::{
    Curve,
};
use rand_core::RngCore;
use rayon::prelude::*;
use std::{any::TypeId, convert::TryInto, num::ParseIntError, ops::Index};
use std::{
    collections::BTreeMap,
@@ -390,6 +391,132 @@ type ExpressionPair<F> = (Polynomial<F, LagrangeCoeff>, Polynomial<F, LagrangeCoeff>);
/// that has the corresponding value in S'.
/// This method returns (A', S') if no errors are encountered.
fn permute_expression_pair<'params, C: CurveAffine, P: Params<'params, C>, R: RngCore>(
    pk: &ProvingKey<C>,
    params: &P,
    domain: &EvaluationDomain<C::Scalar>,
    rng: R,
    input_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
    table_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
) -> Result<ExpressionPair<C::Scalar>, Error> {
    // Heuristic on when multi-threading isn't worth it;
    // for now it seems like multi-threading is often worth it.
    /*
    let num_threads = rayon::current_num_threads();
    if params.n() < (num_threads as u64) << 10 {
        return permute_expression_pair_seq(
            pk,
            params,
            domain,
            rng,
            input_expression,
            table_expression,
        );
    }*/
    let start = std::time::Instant::now();
    let res =
        permute_expression_pair_par(pk, params, domain, rng, input_expression, table_expression);
    dbg!(start.elapsed());
    res
}

fn permute_expression_pair_par<'params, C: CurveAffine, P: Params<'params, C>, R: RngCore>(
    pk: &ProvingKey<C>,
    params: &P,
    domain: &EvaluationDomain<C::Scalar>,
    mut rng: R,
    input_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
    table_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
) -> Result<ExpressionPair<C::Scalar>, Error> {
    let num_threads = rayon::current_num_threads();
    let blinding_factors = pk.vk.cs.blinding_factors();
    let usable_rows = params.n() as usize - (blinding_factors + 1);

    let input_expression = &input_expression[0..usable_rows];

    // Count input_expression's unique values into a BTreeMap, using a rayon parallel fold + reduce.
    let capacity = usable_rows / num_threads + 1;
    let input_uniques: BTreeMap<C::Scalar, usize> = input_expression
        .par_iter()
        .fold(BTreeMap::new, |mut acc, coeff| {
            *acc.entry(*coeff).or_insert(0) += 1;
            acc
        })
        .reduce_with(|mut m1, m2| {
            m2.into_iter().for_each(|(k, v)| {
                *m1.entry(k).or_insert(0) += v;
            });
            m1
        })
        .unwrap();

    let input_unique_ranges = input_uniques
        .par_iter()
        .fold(
            || Vec::with_capacity(capacity),
            |mut input_ranges, (&coeff, &count)| {
                if input_ranges.is_empty() {
                    input_ranges.push((coeff, 0..count));
                } else {
                    let prev_end = input_ranges.last().unwrap().1.end;
Member
If we're already checking for the empty case, then we can use the unchecked unwrap here.

Suggested change
let prev_end = input_ranges.last().unwrap().1.end;
let prev_end = unsafe { input_ranges.last().unwrap_unchecked().1.end };

                    input_ranges.push((coeff, prev_end..prev_end + count));
                }
                input_ranges
            },
        )
        .reduce_with(|r1, mut r2| {
            let r1_end = r1.last().unwrap().1.end;
Member
Not sure how we know we will never panic here.

Author
Each r1 is the result of a previous fold step. As long as the fold runs over a nonempty iterator, its output should be nonempty, so r1 is nonempty unless input_uniques is empty, I believe.

            r2.par_iter_mut().for_each(|r2| {
                r2.1.start += r1_end;
                r2.1.end += r1_end;
            });
            [r1, r2].concat()
        })
        .unwrap();

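A note on the two unwrap threads above: the branch on is_empty can also be avoided entirely by reading the previous end with map_or, which handles the first entry explicitly; rayon's reduce_with only returns None for an empty parallel iterator, which is what the trailing .unwrap() relies on. A minimal sequential sketch, not part of this diff, with u64 standing in for C::Scalar:

use std::collections::BTreeMap;
use std::ops::Range;

/// Builds the same cumulative ranges without any unwrap: the first entry
/// starts at 0 and every later entry starts where the previous one ended.
fn unique_ranges(uniques: &BTreeMap<u64, usize>) -> Vec<(u64, Range<usize>)> {
    let mut ranges: Vec<(u64, Range<usize>)> = Vec::with_capacity(uniques.len());
    for (&coeff, &count) in uniques {
        let prev_end = ranges.last().map_or(0, |(_, r)| r.end);
        ranges.push((coeff, prev_end..prev_end + count));
    }
    ranges
}

The same map_or expression drops into the rayon fold closure in place of the is_empty branch.
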
    let mut sorted_table_coeffs = table_expression[0..usable_rows].to_vec();
    sorted_table_coeffs.par_sort();

    let leftover_table_coeffs: Vec<C::Scalar> = sorted_table_coeffs
        .par_iter()
        .enumerate()
        .filter_map(|(i, coeff)| {
            ((i != 0 && coeff == &sorted_table_coeffs[i - 1]) || !input_uniques.contains_key(coeff))
                .then_some(*coeff)
        })
        .collect();

    // Didn't want to bother with a Sync rng or anything, so just do this part sequentially.
    let blinding: Vec<(C::Scalar, C::Scalar)> = (usable_rows..params.n() as usize)
        .into_iter()
        .map(|_| (C::Scalar::random(&mut rng), C::Scalar::random(&mut rng)))
        .collect();
Comment on lines +488 to +492

Member
Hmmm, we can maybe file an issue in case this turns out to be performance-critical.

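If such an issue is filed, one possible direction is sketched below (not part of this PR; parallel_blinding is a hypothetical helper). It uses rayon's map_init to give each worker its own OsRng, so no Sync bound is needed on the rng type, but it ignores the caller-supplied rng for these rows, which is a behavioural change that would need sign-off:

use ff::Field;
use rand_core::OsRng;
use rayon::prelude::*;

/// Hypothetical parallel version of the blinding-row generation above.
/// Each rayon worker initialises its own OsRng, at the cost of ignoring
/// the rng passed in by the caller for these rows.
fn parallel_blinding<F: Field>(usable_rows: usize, n: usize) -> Vec<(F, F)> {
    (usable_rows..n)
        .into_par_iter()
        .map_init(
            || OsRng,
            |rng, _| (F::random(&mut *rng), F::random(&mut *rng)),
        )
        .collect()
}
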
    let (permuted_input_expression, permuted_table_coeffs): (Vec<_>, Vec<_>) = input_unique_ranges
        .into_par_iter()
        .enumerate()
        .flat_map(|(i, (coeff, range))| {
            // Subtract off the number of table rows that correspond to the unique input values seen so far.
            let leftover_range_start = range.start - i;
            let leftover_range_end = range.end - i - 1;
            [(coeff, coeff)].into_par_iter().chain(
                leftover_table_coeffs[leftover_range_start..leftover_range_end]
                    .par_iter()
                    .map(move |leftover_table_coeff| (coeff, *leftover_table_coeff)),
            )
        })
        .chain(blinding)
        .unzip();

    assert_eq!(permuted_input_expression.len(), params.n() as usize);
    assert_eq!(permuted_table_coeffs.len(), params.n() as usize);

    Ok((
        domain.lagrange_from_vec(permuted_input_expression),
        domain.lagrange_from_vec(permuted_table_coeffs),
    ))
}

#[allow(dead_code)]
fn permute_expression_pair_seq<'params, C: CurveAffine, P: Params<'params, C>, R: RngCore>(
    pk: &ProvingKey<C>,
    params: &P,
    domain: &EvaluationDomain<C::Scalar>,
@@ -404,7 +531,7 @@ fn permute_expression_pair<'params, C: CurveAffine, P: Params<'params, C>, R: RngCore>(
    permuted_input_expression.truncate(usable_rows);

    // Sort input lookup expression values
    permuted_input_expression.sort();
    permuted_input_expression.par_sort();

    // A BTreeMap of each unique element in the table expression and its count
    let mut leftover_table_map: BTreeMap<C::Scalar, u32> = table_expression
16 changes: 15 additions & 1 deletion halo2_proofs/src/poly.rs
@@ -12,7 +12,7 @@ use group::ff::{BatchInvert, Field};
use std::fmt::Debug;
use std::io;
use std::marker::PhantomData;
use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, RangeFrom, RangeFull, Sub};
use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, Range, RangeFrom, RangeFull, Sub};

/// Generic commitment scheme structures
pub mod commitment;
@@ -112,6 +112,20 @@ impl<F, B> IndexMut<RangeFull> for Polynomial<F, B> {
    }
}

impl<F, B> Index<Range<usize>> for Polynomial<F, B> {
    type Output = [F];

    fn index(&self, index: Range<usize>) -> &[F] {
        self.values.index(index)
    }
}

impl<F, B> IndexMut<Range<usize>> for Polynomial<F, B> {
    fn index_mut(&mut self, index: Range<usize>) -> &mut [F] {
        self.values.index_mut(index)
    }
}

impl<F, B> Deref for Polynomial<F, B> {
    type Target = [F];

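
The two Range<usize> impls above are what let the parallel prover borrow plain slices of a Polynomial, e.g. the &input_expression[0..usable_rows] slice taken in permute_expression_pair_par, instead of cloning the values into a new Vec first.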