Skip to content

Commit

Permalink
Add prove verification error. (#498)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/498)
<!-- Reviewable:end -->
  • Loading branch information
alonh5 authored Mar 20, 2024
1 parent 8839cd3 commit 609d2c1
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 42 deletions.
60 changes: 31 additions & 29 deletions src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ impl<H: Hasher<NativeType = u8>> FriVerifier<H> {
config: FriConfig,
proof: FriProof<H>,
column_bounds: Vec<CirclePolyDegreeBound>,
) -> Result<Self, VerificationError> {
) -> Result<Self, FriVerificationError> {
assert!(column_bounds.is_sorted_by_key(|b| Reverse(*b)));

let max_column_bound = column_bounds[0];
Expand Down Expand Up @@ -359,19 +359,19 @@ impl<H: Hasher<NativeType = u8>> FriVerifier<H> {

layer_bound = layer_bound
.fold(FOLD_STEP)
.ok_or(VerificationError::InvalidNumFriLayers)?;
.ok_or(FriVerificationError::InvalidNumFriLayers)?;
layer_domain = layer_domain.double();
}

if layer_bound.log_degree_bound != config.log_last_layer_degree_bound {
return Err(VerificationError::InvalidNumFriLayers);
return Err(FriVerificationError::InvalidNumFriLayers);
}

let last_layer_domain = layer_domain;
let last_layer_poly = proof.last_layer_poly;

if last_layer_poly.len() > (1 << config.log_last_layer_degree_bound) {
return Err(VerificationError::LastLayerDegreeInvalid);
return Err(FriVerificationError::LastLayerDegreeInvalid);
}

channel.mix_felts(&last_layer_poly);
Expand Down Expand Up @@ -402,7 +402,7 @@ impl<H: Hasher<NativeType = u8>> FriVerifier<H> {
pub fn decommit<F>(
mut self,
decommitted_values: Vec<SparseCircleEvaluation<F>>,
) -> Result<(), VerificationError>
) -> Result<(), FriVerificationError>
where
F: ExtensionOf<BaseField>,
SecureField: ExtensionOf<F>,
Expand All @@ -415,7 +415,7 @@ impl<H: Hasher<NativeType = u8>> FriVerifier<H> {
self,
queries: &Queries,
decommitted_values: Vec<SparseCircleEvaluation<F>>,
) -> Result<(), VerificationError>
) -> Result<(), FriVerificationError>
where
F: ExtensionOf<BaseField>,
SecureField: ExtensionOf<F>,
Expand All @@ -436,7 +436,7 @@ impl<H: Hasher<NativeType = u8>> FriVerifier<H> {
&self,
queries: &Queries,
decommitted_values: Vec<SparseCircleEvaluation<F>>,
) -> Result<(Queries, Vec<SecureField>), VerificationError>
) -> Result<(Queries, Vec<SecureField>), FriVerificationError>
where
F: ExtensionOf<BaseField>,
SecureField: ExtensionOf<F> + Field,
Expand Down Expand Up @@ -480,7 +480,7 @@ impl<H: Hasher<NativeType = u8>> FriVerifier<H> {
self,
queries: Queries,
query_evals: Vec<SecureField>,
) -> Result<(), VerificationError> {
) -> Result<(), FriVerificationError> {
let Self {
last_layer_domain: domain,
last_layer_poly,
Expand All @@ -491,7 +491,7 @@ impl<H: Hasher<NativeType = u8>> FriVerifier<H> {
let x = domain.at(bit_reverse_index(query, domain.log_size()));

if query_eval != last_layer_poly.eval_at_point(x.into()) {
return Err(VerificationError::LastLayerEvaluationsInvalid);
return Err(FriVerificationError::LastLayerEvaluationsInvalid);
}
}

Expand Down Expand Up @@ -556,8 +556,8 @@ pub trait FriChannel {
fn draw(&mut self) -> Self::Field;
}

#[derive(Error, Debug)]
pub enum VerificationError {
#[derive(Clone, Copy, Debug, Error)]
pub enum FriVerificationError {
#[error("proof contains an invalid number of FRI layers")]
InvalidNumFriLayers,
#[error("queries do not resolve to their commitment in layer {layer}")]
Expand Down Expand Up @@ -668,7 +668,7 @@ impl<H: Hasher<NativeType = u8>> FriLayerVerifier<H> {
&self,
queries: Queries,
evals_at_queries: Vec<SecureField>,
) -> Result<(Queries, Vec<SecureField>), VerificationError> {
) -> Result<(Queries, Vec<SecureField>), FriVerificationError> {
let decommitment = &self.proof.decommitment;
let commitment = self.proof.commitment;

Expand All @@ -684,7 +684,7 @@ impl<H: Hasher<NativeType = u8>> FriLayerVerifier<H> {
if let [eval] = *leaf {
expected_decommitment_evals.push(eval);
} else {
return Err(VerificationError::InnerLayerCommitmentInvalid {
return Err(FriVerificationError::InnerLayerCommitmentInvalid {
layer: self.layer_index,
});
}
Expand All @@ -696,7 +696,7 @@ impl<H: Hasher<NativeType = u8>> FriLayerVerifier<H> {
.flat_map(|e| &e.values);

if !actual_decommitment_evals.eq(&expected_decommitment_evals) {
return Err(VerificationError::InnerLayerCommitmentInvalid {
return Err(FriVerificationError::InnerLayerCommitmentInvalid {
layer: self.layer_index,
});
}
Expand All @@ -715,7 +715,7 @@ impl<H: Hasher<NativeType = u8>> FriLayerVerifier<H> {
.collect::<Vec<usize>>();

if !decommitment.verify(commitment, &decommitment_positions) {
return Err(VerificationError::InnerLayerCommitmentInvalid {
return Err(FriVerificationError::InnerLayerCommitmentInvalid {
layer: self.layer_index,
});
}
Expand All @@ -738,7 +738,7 @@ impl<H: Hasher<NativeType = u8>> FriLayerVerifier<H> {
&self,
queries: &Queries,
evals_at_queries: &[SecureField],
) -> Result<SparseLineEvaluation, VerificationError> {
) -> Result<SparseLineEvaluation, FriVerificationError> {
// Evals provided by the verifier.
let mut evals_at_queries = evals_at_queries.iter().copied();

Expand All @@ -760,7 +760,7 @@ impl<H: Hasher<NativeType = u8>> FriLayerVerifier<H> {
let eval = match subline_queries.next_if_eq(&&eval_position) {
Some(_) => evals_at_queries.next().unwrap(),
None => proof_evals.next().ok_or(
VerificationError::InnerLayerEvaluationsInvalid {
FriVerificationError::InnerLayerEvaluationsInvalid {
layer: self.layer_index,
},
)?,
Expand All @@ -780,7 +780,7 @@ impl<H: Hasher<NativeType = u8>> FriLayerVerifier<H> {

// Check all proof evals have been consumed.
if !proof_evals.is_empty() {
return Err(VerificationError::InnerLayerEvaluationsInvalid {
return Err(FriVerificationError::InnerLayerEvaluationsInvalid {
layer: self.layer_index,
});
}
Expand Down Expand Up @@ -912,7 +912,7 @@ mod tests {

use num_traits::{One, Zero};

use super::{get_opening_positions, SparseCircleEvaluation, VerificationError};
use super::{get_opening_positions, FriVerificationError, SparseCircleEvaluation};
use crate::commitment_scheme::blake2_hash::Blake2sHasher;
use crate::core::backend::cpu::{CPUCircleEvaluation, CPUCirclePoly, CPULineEvaluation};
use crate::core::backend::CPUBackend;
Expand Down Expand Up @@ -1003,7 +1003,7 @@ mod tests {
}

#[test]
fn valid_proof_passes_verification() -> Result<(), VerificationError> {
fn valid_proof_passes_verification() -> Result<(), FriVerificationError> {
const LOG_DEGREE: u32 = 3;
let polynomial = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR);
let log_domain_size = polynomial.domain.log_size();
Expand All @@ -1019,7 +1019,8 @@ mod tests {
}

#[test]
fn valid_proof_with_constant_last_layer_passes_verification() -> Result<(), VerificationError> {
fn valid_proof_with_constant_last_layer_passes_verification() -> Result<(), FriVerificationError>
{
const LOG_DEGREE: u32 = 3;
const LAST_LAYER_LOG_BOUND: u32 = 0;
let polynomial = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR);
Expand All @@ -1036,7 +1037,7 @@ mod tests {
}

#[test]
fn valid_mixed_degree_proof_passes_verification() -> Result<(), VerificationError> {
fn valid_mixed_degree_proof_passes_verification() -> Result<(), FriVerificationError> {
const LOG_DEGREES: [u32; 3] = [6, 5, 4];
let polynomials = LOG_DEGREES.map(|log_d| polynomial_evaluation(log_d, LOG_BLOWUP_FACTOR));
let log_domain_size = polynomials[0].domain.log_size();
Expand All @@ -1052,7 +1053,8 @@ mod tests {
}

#[test]
fn valid_mixed_degree_end_to_end_proof_passes_verification() -> Result<(), VerificationError> {
fn valid_mixed_degree_end_to_end_proof_passes_verification() -> Result<(), FriVerificationError>
{
const LOG_DEGREES: [u32; 3] = [6, 5, 4];
let polynomials = LOG_DEGREES.map(|log_d| polynomial_evaluation(log_d, LOG_BLOWUP_FACTOR));
let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, 3);
Expand Down Expand Up @@ -1088,7 +1090,7 @@ mod tests {

assert!(matches!(
verifier,
Err(VerificationError::InvalidNumFriLayers)
Err(FriVerificationError::InvalidNumFriLayers)
));
}

Expand All @@ -1110,7 +1112,7 @@ mod tests {

assert!(matches!(
verifier,
Err(VerificationError::InvalidNumFriLayers)
Err(FriVerificationError::InvalidNumFriLayers)
));
}

Expand All @@ -1133,7 +1135,7 @@ mod tests {

assert!(matches!(
verification_result,
Err(VerificationError::InnerLayerEvaluationsInvalid { layer: 1 })
Err(FriVerificationError::InnerLayerEvaluationsInvalid { layer: 1 })
));
}

Expand All @@ -1156,7 +1158,7 @@ mod tests {

assert!(matches!(
verification_result,
Err(VerificationError::InnerLayerCommitmentInvalid { layer: 1 })
Err(FriVerificationError::InnerLayerCommitmentInvalid { layer: 1 })
));
}

Expand All @@ -1178,7 +1180,7 @@ mod tests {

assert!(matches!(
verifier,
Err(VerificationError::LastLayerDegreeInvalid)
Err(FriVerificationError::LastLayerDegreeInvalid)
));
}

Expand All @@ -1201,7 +1203,7 @@ mod tests {

assert!(matches!(
verification_result,
Err(VerificationError::LastLayerEvaluationsInvalid)
Err(FriVerificationError::LastLayerEvaluationsInvalid)
));
}

Expand Down
20 changes: 14 additions & 6 deletions src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::iter::zip;
use itertools::{enumerate, Itertools};
use thiserror::Error;

use super::fri::FriVerificationError;
use super::poly::circle::{CanonicCoset, MAX_CIRCLE_DOMAIN_LOG_SIZE};
use super::queries::SparseSubCircleDomain;
use super::ColumnVec;
Expand Down Expand Up @@ -158,7 +159,11 @@ pub fn prove(
})
}

pub fn verify(proof: StarkProof, air: &impl Air<CPUBackend>, channel: &mut Channel) -> bool {
pub fn verify(
proof: StarkProof,
air: &impl Air<CPUBackend>,
channel: &mut Channel,
) -> Result<(), VerificationError> {
// Read trace commitment.
let mut commitment_scheme = CommitmentSchemeVerifier::new();
commitment_scheme.commit(proof.commitments[0], channel);
Expand All @@ -185,8 +190,7 @@ pub fn verify(proof: StarkProof, air: &impl Air<CPUBackend>, channel: &mut Chann

let bounds = air.quotient_log_bounds();
let fri_config = FriConfig::new(LOG_LAST_LAYER_DEGREE_BOUND, LOG_BLOWUP_FACTOR, N_QUERIES);
let mut fri_verifier =
FriVerifier::commit(channel, fri_config, proof.fri_proof, bounds).unwrap();
let mut fri_verifier = FriVerifier::commit(channel, fri_config, proof.fri_proof, bounds)?;

ProofOfWork::new(PROOF_OF_WORK_BITS).verify(channel, &proof.proof_of_work);
let opening_positions = fri_verifier
Expand All @@ -210,9 +214,7 @@ pub fn verify(proof: StarkProof, air: &impl Air<CPUBackend>, channel: &mut Chann
oods_point,
);

fri_verifier.decommit(sparse_circle_evaluations).unwrap();

true
Ok(fri_verifier.decommit(sparse_circle_evaluations)?)
}

fn prepare_fri_evaluations(
Expand Down Expand Up @@ -299,6 +301,12 @@ pub enum ProvingError {
ConstraintsNotSatisfied,
}

#[derive(Clone, Copy, Debug, Error)]
pub enum VerificationError {
#[error(transparent)]
Fri(#[from] FriVerificationError),
}

#[cfg(test)]
mod tests {
use num_traits::Zero;
Expand Down
17 changes: 10 additions & 7 deletions src/fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::{FieldExpOps, IntoSlice};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, verify, ProvingError, StarkProof};
use crate::core::prover::{prove, verify, ProvingError, StarkProof, VerificationError};

pub mod air;
mod component;
Expand Down Expand Up @@ -57,7 +57,10 @@ impl Fibonacci {
}
}

pub fn verify_proof<const N_BITS: u32>(proof: StarkProof, claim: BaseField) -> bool {
pub fn verify_proof<const N_BITS: u32>(
proof: StarkProof,
claim: BaseField,
) -> Result<(), VerificationError> {
let fib = Fibonacci::new(N_BITS, claim);
let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[claim])));
verify(proof, &fib.air, channel)
Expand Down Expand Up @@ -211,7 +214,7 @@ mod tests {
.composition_polynomial_oods_value,
hz
);
assert!(verify_proof::<FIB_LOG_SIZE>(proof, fib.claim));
verify_proof::<FIB_LOG_SIZE>(proof, fib.claim).unwrap();
}

// TODO(AlonH): Check the correct error occurs after introducing errors instead of
Expand All @@ -225,7 +228,7 @@ mod tests {
let mut invalid_proof = fib.prove().unwrap();
invalid_proof.opened_values.0[0][0][4] += BaseField::one();

verify_proof::<FIB_LOG_SIZE>(invalid_proof, fib.claim);
verify_proof::<FIB_LOG_SIZE>(invalid_proof, fib.claim).unwrap();
}

// TODO(AlonH): Check the correct error occurs after introducing errors instead of
Expand All @@ -239,7 +242,7 @@ mod tests {
let mut invalid_proof = fib.prove().unwrap();
invalid_proof.trace_oods_values.swap(0, 1);

verify_proof::<FIB_LOG_SIZE>(invalid_proof, fib.claim);
verify_proof::<FIB_LOG_SIZE>(invalid_proof, fib.claim).unwrap();
}

// TODO(AlonH): Check the correct error occurs after introducing errors instead of
Expand All @@ -253,7 +256,7 @@ mod tests {
let mut invalid_proof = fib.prove().unwrap();
invalid_proof.opened_values.0[0][0].pop();

verify_proof::<FIB_LOG_SIZE>(invalid_proof, fib.claim);
verify_proof::<FIB_LOG_SIZE>(invalid_proof, fib.claim).unwrap();
}

#[test]
Expand All @@ -267,6 +270,6 @@ mod tests {
let proof = prove(&air, prover_channel, trace).unwrap();
let verifier_channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[fib.claim])));
assert!(verify(proof, &air, verifier_channel));
verify(proof, &air, verifier_channel).unwrap();
}
}

0 comments on commit 609d2c1

Please sign in to comment.