Skip to content

Commit

Permalink
Improve error handling for core functions
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkingLee committed Jul 27, 2023
1 parent fc25554 commit 0868497
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 50 deletions.
2 changes: 1 addition & 1 deletion crates/melo-erasure-coding/src/erasure_coding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn extend(fs: &FsFFTSettings, source: &[BlsScalar]) -> Result<Vec<BlsScalar>
pub fn extend_poly(fs: &FsFFTSettings, poly: &Polynomial) -> Result<Vec<BlsScalar>, String> {
let mut coeffs = poly.0.coeffs.clone();
coeffs.resize(coeffs.len() * 2, FsFr::zero());
let mut extended_coeffs_fft = fs.fft_fr(&coeffs, false).unwrap();
let mut extended_coeffs_fft = fs.fft_fr(&coeffs, false)?;
reverse_bit_order(&mut extended_coeffs_fft);
Ok(BlsScalar::vec_from_repr(extended_coeffs_fft))
}
Expand Down
88 changes: 48 additions & 40 deletions crates/melo-erasure-coding/src/extend_col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,60 @@ use rust_kzg_blst::types::fft_settings::FsFFTSettings;
use crate::erasure_coding::{extend, extend_fs_g1};

pub fn extend_segments_col(
fs: &FsFFTSettings,
segments: &Vec<Segment>,
fs: &FsFFTSettings,
segments: &Vec<Segment>,
) -> Result<Vec<Segment>, String> {
let k = segments.len();
let x = segments[0].position.x;
let segment_size = segments[0].size();
let k = segments.len();
let x = segments[0].position.x;
let segment_size = segments[0].size();

let mut proofs = vec![];
let sorted_rows: Vec<BlsScalar> = segments
.iter()
.sorted_by_key(|s| s.position.y)
.enumerate()
.filter(|(i, s)| s.position.x == x && *i == s.position.y as usize && s.size() == segment_size)
.flat_map(|(_, s)| {
proofs.push(s.content.proof.clone());
s.content.data.clone()
})
.collect();
if segments.iter().any(|s| s.position.x != x) {
return Err("segments are not from the same column".to_string());
}

if sorted_rows.len() != k * segment_size {
return Err("segments x not equal".to_string());
}
if !k.is_power_of_two() || !segment_size.is_power_of_two() {
return Err("number of segments and segment size must be powers of two".to_string());
}

let extended_proofs = extend_fs_g1(fs, &proofs)?;
let mut proofs = vec![];
let sorted_rows: Vec<BlsScalar> = segments
.iter()
.sorted_by_key(|s| s.position.y)
.enumerate()
.filter(|(i, s)| s.position.x == x && *i == s.position.y as usize && s.size() == segment_size)
.flat_map(|(_, s)| {
proofs.push(s.content.proof.clone());
s.content.data.clone()
})
.collect();

let mut extended_cols = vec![];
if sorted_rows.len() != k * segment_size {
return Err("mismatch in the number of elements after sorting".to_string());
}

for i in 0..(segment_size) {
let col: Vec<BlsScalar> = sorted_rows
.iter()
.skip(i)
.step_by(segment_size)
.map(|s| s.clone())
.collect::<Vec<BlsScalar>>();
extended_cols.push(extend(fs, &col)?);
}
let extended_proofs = extend_fs_g1(fs, &proofs)?;

let mut extended_segments = vec![];
let mut extended_cols = vec![];

// 需要获取奇数部分
extended_proofs.iter().skip(1).step_by(2).enumerate().for_each(|(i, proof)| {
let position = melo_core_primitives::kzg::Position { x, y: (i + k) as u32 };
let data = extended_cols.iter().map(|col| col[i]).collect::<Vec<BlsScalar>>();
let segment = Segment { position, content: SegmentData { data, proof: proof.clone() } };
extended_segments.push(segment);
});
for i in 0..(segment_size) {
let col: Vec<BlsScalar> = sorted_rows
.iter()
.skip(i)
.step_by(segment_size)
.map(|s| s.clone())
.collect::<Vec<BlsScalar>>();
extended_cols.push(extend(fs, &col)?);
}

Ok(extended_segments)
}
let mut extended_segments = vec![];

// Need to obtain odd parts
extended_proofs.iter().skip(1).step_by(2).enumerate().for_each(|(i, proof)| {
let position = melo_core_primitives::kzg::Position { x, y: (i + k) as u32 };
let data = extended_cols.iter().map(|col| col[i]).collect::<Vec<BlsScalar>>();
let segment = Segment { position, content: SegmentData { data, proof: proof.clone() } };
extended_segments.push(segment);
});

Ok(extended_segments)
}
11 changes: 11 additions & 0 deletions crates/melo-erasure-coding/src/recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ pub fn recovery_row_from_segments(
) -> Result<Vec<Segment>, String> {
let y = segments[0].position.y;
let segments_size = segments[0].size();

if segments.iter().any(|s| s.position.y != y) {
return Err("segments are not from the same row".to_string());
}
if !segments_size.is_power_of_two() || !chunk_count.is_power_of_two() {
return Err("segment size and chunk_count must be a power of two".to_string());
}
if segments.iter().any(|s| s.size() != segments_size) {
return Err("segments are not of the same size".to_string());
}

let order_segments = order_segments_row(&segments, chunk_count)?;
let mut row = segment_datas_to_row(&order_segments, segments_size);
reverse_bit_order(&mut row);
Expand Down
8 changes: 7 additions & 1 deletion crates/melo-erasure-coding/src/segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,13 @@ pub fn segment_datas_to_row(segments: &Vec<Option<SegmentData>>, chunk_size: usi
///
/// A `Result` containing a vector of `Segment` structs or an error message.
pub fn poly_to_segment_vec(poly: &Polynomial, kzg: &KZG, y: usize, chunk_size: usize) -> Result<Vec<Segment>, String> {
let poly_len = poly.0.coeffs.len();
let poly_len = poly.checked()?.0.coeffs.len();

// chunk_size must be a power of two
if !chunk_size.is_power_of_two() {
return Err("chunk_size must be a power of two".to_string());
}

let fk = FsFK20MultiSettings::new(&kzg.ks, 2 * poly_len, chunk_size).unwrap();
let all_proofs = fk.data_availability(&poly.0).unwrap();
let extended_poly = extend_poly(&fk.kzg_settings.fs, &poly)?;
Expand Down
109 changes: 103 additions & 6 deletions crates/melo-erasure-coding/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ fn order_segments_col_test() {
assert!(col.is_err());
}


#[test]
fn poly_to_segment_vec_test() {
// Build a random polynomial
Expand Down Expand Up @@ -467,6 +466,14 @@ fn extend_poly_test() {
let extended_poly = extend_poly(kzg.get_fs(), &poly).unwrap();
assert_eq!(extended_poly.len(), 32);

let poly_err = random_poly(3);
let extended_poly_err = extend_poly(kzg.get_fs(), &poly_err);
assert!(extended_poly_err.is_err());

let poly_err = random_poly(6);
let extended_poly_err = extend_poly(kzg.get_fs(), &poly_err);
assert!(extended_poly_err.is_err());

let random_positions = random_vec(num_shards * 2);
let mut cells = [None; 32];
for i in 0..num_shards {
Expand Down Expand Up @@ -506,20 +513,49 @@ fn recovery_row_from_segments_test() {

// Recover segments
let recovered_segments = recovery_row_from_segments(&random_segments, &kzg, chunk_count).unwrap();
assert_eq!(recovered_segments[0], segments[0]);

// Verify if the recovered segments are the same as the original segments
for i in 0..chunk_count {
assert_eq!(recovered_segments[i], segments[i]);
}

// Remove one segment from random_segments
random_segments.remove(0);
let mut segments_err = random_segments.clone();
segments_err.remove(0);
// Recover segments, it should fail due to an incorrect number of segments
let recovered_segments = recovery_row_from_segments(&random_segments, &kzg, chunk_count);
let result = recovery_row_from_segments(&segments_err, &kzg, chunk_count);

// Verify if it fails
assert!(result.is_err());

// Modify one y value in random_segments
let mut segments_err = random_segments.clone();
segments_err[0].position.y = 3;
// Recover segments, it should fail due to incorrect x values
let result = recovery_row_from_segments(&segments_err, &kzg, chunk_count);
// Verify if it fails
assert!(result.is_err());

// segment size and chunk_count must be a power of two
let result = recovery_row_from_segments(&segments_err, &kzg, chunk_count + 1);
assert!(result.is_err());

// remove one of the segment.data
let mut segments_err = random_segments.clone();
segments_err[0].content.data.remove(0);
// Recover segments, it should fail due to incorrect segment.data length
let result = recovery_row_from_segments(&segments_err, &kzg, chunk_count);
// Verify if it fails
assert!(result.is_err());

// segments is not enough
let mut segments_err = random_segments.clone();
segments_err.remove(0);
// Recover segments, it should fail due to incorrect segment.data length
let result = recovery_row_from_segments(&segments_err, &kzg, chunk_count);
// Verify if it fails
assert!(recovered_segments.is_err());
assert!(result.is_err());

}

#[test]
Expand Down Expand Up @@ -617,6 +653,27 @@ fn extend_and_commit_multi_test() {
}
}

fn extend_returns_err_case(
num_shards: usize,
) {
let kzg = KZG::new(embedded_kzg_settings());

let evens = (0..num_shards)
.map(|_| rand::random::<[u8; 31]>())
.map(BlsScalar::from)
.collect::<Vec<_>>();

let result = extend(&kzg.get_fs(), &evens);
assert!(result.is_err());
}

#[test]
fn extend_returns_err_test() {
extend_returns_err_case(5);
extend_returns_err_case(0);
extend_returns_err_case(321);
}

#[test]
fn extend_fs_g1_test() {
let kzg = KZG::new(embedded_kzg_settings());
Expand All @@ -626,7 +683,14 @@ fn extend_fs_g1_test() {
}
let extended_commits = extend_fs_g1(kzg.get_fs(), &commits).unwrap();
assert!(extended_commits.len() == 8);
assert!(extended_commits[2].0 == commits[1].0);

for i in 0..4 {
assert_eq!(extended_commits[i * 2], commits[i]);
}

commits.push(KZGCommitment(FsG1::rand()));
let result = extend_fs_g1(kzg.get_fs(), &commits);
assert!(result.is_err());
}

#[test]
Expand Down Expand Up @@ -659,6 +723,39 @@ fn extend_segments_col_test() {
let pick_s = extended_col[i].clone();
assert!(pick_s.verify(&kzg, &extended_commitments[i * 2 + 1], chunk_count).unwrap());
}

// Modify a single x value in the column
let mut modified_col = extended_col.clone();
modified_col[0].position.x = 3;

// Extend the column, it should fail due to incorrect x values
let extended_col_err = extend_segments_col(kzg.get_fs(), &modified_col);
assert!(extended_col_err.is_err());

// Add 3 random segments to the column
for _ in 0..3 {
let data = (0..chunk_len)
.map(|_| rand::random::<[u8; 31]>())
.map(BlsScalar::from)
.collect::<Vec<_>>();
let proof = KZGProof(FsG1::rand());
let segment_data = SegmentData { data, proof };
let position = Position { x: 0, y: 0 };
modified_col.push(Segment { position, content: segment_data });
}

// Extend the column, it should fail due to an incorrect number of segments
let extended_col_err = extend_segments_col(kzg.get_fs(), &modified_col);
assert!(extended_col_err.is_err());

// Modify a single y value in the column
let mut extended_col_err = extended_col.clone();
extended_col_err[0].position.y = 3;

// Extend the column, it should fail due to incorrect y values
let extended_col = extend_segments_col(kzg.get_fs(), &modified_col);
assert!(extended_col.is_err());

}

#[test]
Expand Down
7 changes: 7 additions & 0 deletions primitives/src/kzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,13 @@ impl Polynomial {
FsPoly::new(size).map(Self)
}

pub fn checked(&self) -> Result<Self, String> {
if !self.0.coeffs.len().is_power_of_two() {
return Err("Polynomial size must be a power of two".to_string());
}
Ok(self.clone())
}

pub fn from_coeffs(coeffs: &[FsFr]) -> Self {
Polynomial(FsPoly { coeffs: coeffs.to_vec() })
}
Expand Down
13 changes: 11 additions & 2 deletions primitives/src/segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ pub struct SegmentData {

impl SegmentData {
pub fn new(proof: KZGProof, size: usize) -> Self {
// TODO: check data length
let arr = vec![BlsScalar::default(); size];
Self { data: arr, proof }
}
Expand All @@ -44,14 +43,24 @@ impl SegmentData {
self.data.len()
}

pub fn checked(&self) -> Result<Self, String> {
if self.data.len() == 0 {
return Err("segment data is empty".to_string());
}
// data.len() is a power of two
if !self.data.len().is_power_of_two() {
return Err("segment data length is not a power of two".to_string());
}
Ok(self.clone())
}

pub fn from_data(
positon: &Position,
content: &[BlsScalar],
kzg: &KZG,
poly: &Polynomial,
chunk_count: usize,
) -> Result<Self, String> {
// let i = kzg.get_kzg_index(chunk_count, positon.x as usize, content.len());
kzg.compute_proof_multi(poly, positon.x as usize, chunk_count, content.len())
.map(|p| Ok(Self { data: content.to_vec(), proof: p }))?
}
Expand Down

0 comments on commit 0868497

Please sign in to comment.