Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add new parallel implementation for permute_expression_pair #189

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 124 additions & 1 deletion halo2_proofs/src/plonk/lookup/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use group::{
Curve,
};
use rand_core::RngCore;
use rayon::prelude::*;
use std::{any::TypeId, convert::TryInto, num::ParseIntError, ops::Index};
use std::{
collections::BTreeMap,
Expand Down Expand Up @@ -390,6 +391,128 @@ type ExpressionPair<F> = (Polynomial<F, LagrangeCoeff>, Polynomial<F, LagrangeCo
/// that has the corresponding value in S'.
/// This method returns (A', S') if no errors are encountered.
fn permute_expression_pair<'params, C: CurveAffine, P: Params<'params, C>, R: RngCore>(
pk: &ProvingKey<C>,
params: &P,
domain: &EvaluationDomain<C::Scalar>,
rng: R,
input_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
table_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
) -> Result<ExpressionPair<C::Scalar>, Error> {
// Dispatcher: every call is currently forwarded, with all arguments unchanged,
// to the rayon-based `permute_expression_pair_par`.
//
// The block comment below preserves a disabled heuristic that would instead
// fall back to `permute_expression_pair_seq` whenever `params.n()` is smaller
// than `num_threads << 10`. It is switched off because, for now, the
// multi-threaded path appears to be worth it in most cases.
/*
let num_threads = rayon::current_num_threads();
if params.n() < (num_threads as u64) << 10 {
return permute_expression_pair_seq(
pk,
params,
domain,
rng,
input_expression,
table_expression,
);
}*/
permute_expression_pair_par(pk, params, domain, rng, input_expression, table_expression)
}

/// Rayon-based construction of the permuted (input, table) column pair.
///
/// Only the first `usable_rows = params.n() - (blinding_factors + 1)` entries of
/// each expression are permuted; the remaining rows are filled with random
/// blinding values drawn from `rng`.
///
/// Returns the two columns wrapped via `domain.lagrange_from_vec`, each of length
/// `params.n()` (enforced by the `assert_eq!` calls at the end).
///
/// NOTE(review): this function never produces an `Err`. In particular, nothing
/// here checks that every distinct input value also occurs in the table slice --
/// confirm whether callers rely on an error being reported in that situation.
///
/// NOTE(review): both `reduce_with(..).unwrap()` calls panic when the parallel
/// iterator is empty, i.e. when `usable_rows` is 0 -- confirm this cannot happen.
fn permute_expression_pair_par<'params, C: CurveAffine, P: Params<'params, C>, R: RngCore>(
pk: &ProvingKey<C>,
params: &P,
domain: &EvaluationDomain<C::Scalar>,
mut rng: R,
input_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
table_expression: &Polynomial<C::Scalar, LagrangeCoeff>,
) -> Result<ExpressionPair<C::Scalar>, Error> {
let num_threads = rayon::current_num_threads();
let blinding_factors = pk.vk.cs.blinding_factors();
// Rows past this index are reserved for blinding and are not permuted.
let usable_rows = params.n() as usize - (blinding_factors + 1);

// Shadow the polynomial with the slice of rows that actually take part.
let input_expression = &input_expression[0..usable_rows];

// Count how often each distinct input value occurs. Each rayon fold segment
// builds its own BTreeMap of counts; reduce merges the maps by summing counts.
// `capacity` is only a preallocation hint for the per-segment vectors below.
let capacity = usable_rows / num_threads + 1;
let input_uniques: BTreeMap<C::Scalar, usize> = input_expression
.par_iter()
.fold(BTreeMap::new, |mut acc, coeff| {
*acc.entry(*coeff).or_insert(0) += 1;
acc
})
.reduce_with(|mut m1, m2| {
m2.into_iter().for_each(|(k, v)| {
*m1.entry(k).or_insert(0) += v;
});
m1
})
.unwrap();

// Turn the counts into consecutive half-open row ranges, one per distinct input
// value, in the key order of the BTreeMap. Within a fold segment the ranges start
// at 0 and each one begins where the previous one ended.
let input_unique_ranges = input_uniques
.par_iter()
.fold(
|| Vec::with_capacity(capacity),
|mut input_ranges, (&coeff, &count)| {
if input_ranges.is_empty() {
input_ranges.push((coeff, 0..count));
} else {
let prev_end = input_ranges.last().unwrap().1.end;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're already checking for empty range, then we should unwrap

Suggested change
let prev_end = input_ranges.last().unwrap().1.end;
let prev_end = unsafe{ input_ranges.last().unwrap_unchecked().1.end};

input_ranges.push((coeff, prev_end..prev_end + count));
}
input_ranges
},
)
.reduce_with(|r1, mut r2| {
// Segment-local ranges in `r2` start at 0, so shift them past the end of `r1`
// before concatenating. This relies on `r1` being non-empty.
let r1_end = r1.last().unwrap().1.end;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how we know we will never panic here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each r1 is the result of the previous fold step. As long as the fold is over nonempty iterator, the output should be nonempty. So r1 is nonempty unless input_uniques is empty I believe.

r2.par_iter_mut().for_each(|r2| {
r2.1.start += r1_end;
r2.1.end += r1_end;
});
[r1, r2].concat()
})
.unwrap();

// Sorted copy of the usable table rows.
let mut sorted_table_coeffs = table_expression[0..usable_rows].to_vec();
sorted_table_coeffs.par_sort();

// Table values left over after pairing: an entry is kept when it repeats its
// predecessor in sorted order, or when its value is not an input value at all.
// Net effect: one occurrence is dropped for each table value that also appears
// among the input uniques.
let leftover_table_coeffs: Vec<C::Scalar> = sorted_table_coeffs
.par_iter()
.enumerate()
.filter_map(|(i, coeff)| {
((i != 0 && coeff == &sorted_table_coeffs[i - 1]) || !input_uniques.contains_key(coeff))
.then_some(*coeff)
})
.collect();

// Blinding rows: one random (input, table) pair per row in usable_rows..n.
// Drawn sequentially because `rng` is used through `&mut` and is not shared
// across threads.
let blinding: Vec<(C::Scalar, C::Scalar)> = (usable_rows..params.n() as usize)
.into_iter()
.map(|_| (C::Scalar::random(&mut rng), C::Scalar::random(&mut rng)))
.collect();
Comment on lines +488 to +492
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmmm We can maybe file an issue in case we see this being critical.

// Assemble the output rows. For the i-th distinct input value with `count`
// occurrences: first the pair (coeff, coeff), then `count - 1` pairs that match
// the same input value with consecutive leftover table values. The blinding pairs
// are appended last and the pairs are unzipped into the two columns.
let (permuted_input_expression, permuted_table_coeffs): (Vec<_>, Vec<_>) = input_unique_ranges
.into_par_iter()
.enumerate()
.flat_map(|(i, (coeff, range))| {
// Each of the `i` earlier uniques consumed one row with its (coeff, coeff)
// pair rather than a leftover value, so shift the range back by `i`; the extra
// `- 1` on the end accounts for this unique's own (coeff, coeff) row.
let leftover_range_start = range.start - i;
let leftover_range_end = range.end - i - 1;
[(coeff, coeff)].into_par_iter().chain(
leftover_table_coeffs[leftover_range_start..leftover_range_end]
.par_iter()
.map(move |leftover_table_coeff| (coeff, *leftover_table_coeff)),
)
})
.chain(blinding)
.unzip();

// Sanity check: both columns must span the full domain.
assert_eq!(permuted_input_expression.len(), params.n() as usize);
assert_eq!(permuted_table_coeffs.len(), params.n() as usize);

Ok((
domain.lagrange_from_vec(permuted_input_expression),
domain.lagrange_from_vec(permuted_table_coeffs),
))
}

#[allow(dead_code)]
fn permute_expression_pair_seq<'params, C: CurveAffine, P: Params<'params, C>, R: RngCore>(
pk: &ProvingKey<C>,
params: &P,
domain: &EvaluationDomain<C::Scalar>,
Expand All @@ -404,7 +527,7 @@ fn permute_expression_pair<'params, C: CurveAffine, P: Params<'params, C>, R: Rn
permuted_input_expression.truncate(usable_rows);

// Sort input lookup expression values
permuted_input_expression.sort();
permuted_input_expression.par_sort();

// A BTreeMap of each unique element in the table expression and its count
let mut leftover_table_map: BTreeMap<C::Scalar, u32> = table_expression
Expand Down
16 changes: 15 additions & 1 deletion halo2_proofs/src/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use group::ff::{BatchInvert, Field};
use std::fmt::Debug;
use std::io;
use std::marker::PhantomData;
use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, RangeFrom, RangeFull, Sub};
use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, Range, RangeFrom, RangeFull, Sub};

/// Generic commitment scheme structures
pub mod commitment;
Expand Down Expand Up @@ -112,6 +112,20 @@ impl<F, B> IndexMut<RangeFull> for Polynomial<F, B> {
}
}

/// Read-only access to a half-open range of the polynomial's stored values.
impl<F, B> Index<Range<usize>> for Polynomial<F, B> {
    type Output = [F];

    /// Returns the sub-slice of `values` selected by `range`; out-of-bounds
    /// ranges panic exactly as indexing the underlying storage does.
    fn index(&self, range: Range<usize>) -> &[F] {
        &self.values[range]
    }
}

/// Mutable access to a half-open range of the polynomial's stored values.
impl<F, B> IndexMut<Range<usize>> for Polynomial<F, B> {
    /// Returns the mutable sub-slice of `values` selected by `range`;
    /// out-of-bounds ranges panic exactly as indexing the underlying storage does.
    fn index_mut(&mut self, range: Range<usize>) -> &mut [F] {
        &mut self.values[range]
    }
}

impl<F, B> Deref for Polynomial<F, B> {
type Target = [F];

Expand Down