Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into 0xaatif/refactor-t…
Browse files Browse the repository at this point in the history
…race-decoder-decoding
  • Loading branch information
0xaatif committed Aug 12, 2024
2 parents 74d3127 + c0e0351 commit 2bd9d32
Show file tree
Hide file tree
Showing 11 changed files with 669 additions and 250 deletions.
132 changes: 81 additions & 51 deletions evm_arithmetization/src/byte_packing/byte_packing_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
//! 256^length, and as a result a different value will be stored in memory.

use core::marker::PhantomData;
use std::borrow::Borrow;

use itertools::Itertools;
use plonky2::field::extension::{Extendable, FieldExtension};
Expand All @@ -46,10 +47,7 @@ use starky::stark::Stark;

use super::NUM_BYTES;
use crate::all_stark::EvmStarkFrame;
use crate::byte_packing::columns::{
index_len, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, IS_READ, LEN_INDICES_COLS,
NUM_COLUMNS, RANGE_COUNTER, RC_FREQUENCIES, TIMESTAMP,
};
use crate::byte_packing::columns::*;
use crate::witness::memory::MemoryAddress;

/// Strict upper bound for the individual bytes range-check.
Expand All @@ -65,7 +63,8 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
// obtain the corresponding limb.
let outputs: Vec<Column<F>> = (0..8)
.map(|i| {
let range = value_bytes(i * 4)..value_bytes(i * 4) + 4;
let range = BYTE_PACKING_COL_MAP.value_bytes[i * 4]
..BYTE_PACKING_COL_MAP.value_bytes[i * 4] + 4;
Column::linear_combination(
range
.enumerate()
Expand All @@ -74,34 +73,54 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
})
.collect();

let sequence_len: Column<F> = Column::linear_combination(
(0..NUM_BYTES).map(|i| (index_len(i), F::from_canonical_usize(i + 1))),
);

Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL])
.chain([sequence_len])
.chain(Column::singles(&[TIMESTAMP]))
.chain(outputs)
.collect()
let sequence_len: Column<F> = Column::linear_combination((0..NUM_BYTES).map(|i| {
(
BYTE_PACKING_COL_MAP.index_len[i],
F::from_canonical_usize(i + 1),
)
}));

Column::singles([
BYTE_PACKING_COL_MAP.is_read,
BYTE_PACKING_COL_MAP.addr_context,
BYTE_PACKING_COL_MAP.addr_segment,
BYTE_PACKING_COL_MAP.addr_virtual,
])
.chain([sequence_len])
.chain(Column::singles(&[BYTE_PACKING_COL_MAP.timestamp]))
.chain(outputs)
.collect()
}

/// CTL filter for the `BytePackingStark` looked table.
pub(crate) fn ctl_looked_filter<F: Field>() -> Filter<F> {
// The CPU table is only interested in our sequence end rows,
// since those contain the final limbs of our packed int.
Filter::new_simple(Column::sum((0..NUM_BYTES).map(index_len)))
Filter::new_simple(Column::sum(
(0..NUM_BYTES).map(|i| BYTE_PACKING_COL_MAP.index_len[i]),
))
}

/// Column linear combination for the `BytePackingStark` table reading/writing
/// the `i`th byte sequence from `MemoryStark`.
pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
let mut res = Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT]).collect_vec();
let mut res = Column::singles([
BYTE_PACKING_COL_MAP.is_read,
BYTE_PACKING_COL_MAP.addr_context,
BYTE_PACKING_COL_MAP.addr_segment,
])
.collect_vec();

// Compute the virtual address: `ADDR_VIRTUAL` + `sequence_len` - 1 - i.
let sequence_len_minus_one = (0..NUM_BYTES)
.map(|j| (index_len(j), F::from_canonical_usize(j)))
.map(|j| {
(
BYTE_PACKING_COL_MAP.index_len[j],
F::from_canonical_usize(j),
)
})
.collect::<Vec<_>>();
let mut addr_virt_cols = vec![(ADDR_VIRTUAL, F::ONE)];
let mut addr_virt_cols = vec![(BYTE_PACKING_COL_MAP.addr_virtual, F::ONE)];
addr_virt_cols.extend(sequence_len_minus_one);
let addr_virt = Column::linear_combination_with_constant(
addr_virt_cols,
Expand All @@ -111,20 +130,22 @@ pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
res.push(addr_virt);

// The i'th input byte being read/written.
res.push(Column::single(value_bytes(i)));
res.push(Column::single(BYTE_PACKING_COL_MAP.value_bytes[i]));

// Since we're reading a single byte, the higher limbs must be zero.
res.extend((1..8).map(|_| Column::zero()));

res.push(Column::single(TIMESTAMP));
res.push(Column::single(BYTE_PACKING_COL_MAP.timestamp));

res
}

/// CTL filter for reading/writing the `i`th byte of the byte sequence from/to
/// memory.
pub(crate) fn ctl_looking_memory_filter<F: Field>(i: usize) -> Filter<F> {
Filter::new_simple(Column::sum((i..NUM_BYTES).map(index_len)))
Filter::new_simple(Column::sum(
(i..NUM_BYTES).map(|i| BYTE_PACKING_COL_MAP.index_len[i]),
))
}

/// Information about a byte packing operation needed for witness generation.
Expand Down Expand Up @@ -205,24 +226,24 @@ impl<F: RichField + Extendable<D>, const D: usize> BytePackingStark<F, D> {
virt,
} = base_address;

let mut row = [F::ZERO; NUM_COLUMNS];
row[IS_READ] = F::from_bool(is_read);
let mut row = BytePackingColumnsView::default();
row.is_read = F::from_bool(is_read);

row[ADDR_CONTEXT] = F::from_canonical_usize(context);
row[ADDR_SEGMENT] = F::from_canonical_usize(segment);
row.addr_context = F::from_canonical_usize(context);
row.addr_segment = F::from_canonical_usize(segment);
// We store the initial virtual segment. But the CTLs,
// we start with virt + sequence_len - 1.
row[ADDR_VIRTUAL] = F::from_canonical_usize(virt);
row.addr_virtual = F::from_canonical_usize(virt);

row[TIMESTAMP] = F::from_canonical_usize(timestamp);
row.timestamp = F::from_canonical_usize(timestamp);

row[index_len(bytes.len() - 1)] = F::ONE;
row.index_len[bytes.len() - 1] = F::ONE;

for (i, &byte) in bytes.iter().rev().enumerate() {
row[value_bytes(i)] = F::from_canonical_u8(byte);
row.value_bytes[i] = F::from_canonical_u8(byte);
}

row
row.into()
}

const fn generate_padding_row(&self) -> [F; NUM_COLUMNS] {
Expand All @@ -237,26 +258,27 @@ impl<F: RichField + Extendable<D>, const D: usize> BytePackingStark<F, D> {
debug_assert!(cols.iter().all(|col| col.len() == n_rows));

for i in 0..BYTE_RANGE_MAX {
cols[RANGE_COUNTER][i] = F::from_canonical_usize(i);
cols[BYTE_PACKING_COL_MAP.range_counter][i] = F::from_canonical_usize(i);
}
for i in BYTE_RANGE_MAX..n_rows {
cols[RANGE_COUNTER][i] = F::from_canonical_usize(BYTE_RANGE_MAX - 1);
cols[BYTE_PACKING_COL_MAP.range_counter][i] =
F::from_canonical_usize(BYTE_RANGE_MAX - 1);
}

// For each column c in cols, generate the range-check
// permutations and put them in the corresponding range-check
// columns rc_c and rc_c+1.
for col in 0..NUM_BYTES {
for i in 0..n_rows {
let c = value_bytes(col);
let c = BYTE_PACKING_COL_MAP.value_bytes[col];
let x = cols[c][i].to_canonical_u64() as usize;
assert!(
x < BYTE_RANGE_MAX,
"column value {} exceeds the max range value {}",
x,
BYTE_RANGE_MAX
);
cols[RC_FREQUENCIES][x] += F::ONE;
cols[BYTE_PACKING_COL_MAP.rc_frequencies][x] += F::ONE;
}
}
}
Expand All @@ -279,13 +301,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt
P: PackedField<Scalar = FE>,
{
let local_values: &[P; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap();
let local_values: &BytePackingColumnsView<P> = local_values.borrow();
let next_values: &[P; NUM_COLUMNS] = vars.get_next_values().try_into().unwrap();
let next_values: &BytePackingColumnsView<P> = next_values.borrow();

// Check the range column: First value must be 0, last row
// must be 255, and intermediate rows must increment by 0
// or 1.
let rc1 = local_values[RANGE_COUNTER];
let rc2 = next_values[RANGE_COUNTER];
let rc1 = local_values.range_counter;
let rc2 = next_values.range_counter;
yield_constr.constraint_first_row(rc1);
let incr = rc2 - rc1;
yield_constr.constraint_transition(incr * incr - incr);
Expand All @@ -296,32 +320,32 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt

// We filter active columns by summing all the byte indices.
// Constraining each of them to be boolean is done later on below.
let current_filter = local_values[LEN_INDICES_COLS].iter().copied().sum::<P>();
let current_filter = local_values.index_len.iter().copied().sum::<P>();
yield_constr.constraint(current_filter * (current_filter - one));

// The filter column must start by one.
yield_constr.constraint_first_row(current_filter - one);

// The is_read flag must be boolean.
let current_is_read = local_values[IS_READ];
let current_is_read = local_values.is_read;
yield_constr.constraint(current_is_read * (current_is_read - one));

// Each byte index must be boolean.
for i in 0..NUM_BYTES {
let idx_i = local_values[index_len(i)];
let idx_i = local_values.index_len[i];
yield_constr.constraint(idx_i * (idx_i - one));
}

// Only padding rows have their filter turned off.
let next_filter = next_values[LEN_INDICES_COLS].iter().copied().sum::<P>();
let next_filter = next_values.index_len.iter().copied().sum::<P>();
yield_constr.constraint_transition(next_filter * (next_filter - current_filter));

// Check that all limbs after final length are 0.
for i in 0..NUM_BYTES - 1 {
// If the length is i+1, then value_bytes(i+1),...,value_bytes(NUM_BYTES-1) must
// be 0.
for j in i + 1..NUM_BYTES {
yield_constr.constraint(local_values[index_len(i)] * local_values[value_bytes(j)]);
yield_constr.constraint(local_values.index_len[i] * local_values.value_bytes[j]);
}
}
}
Expand All @@ -334,14 +358,16 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt
) {
let local_values: &[ExtensionTarget<D>; NUM_COLUMNS] =
vars.get_local_values().try_into().unwrap();
let local_values: &BytePackingColumnsView<ExtensionTarget<D>> = local_values.borrow();
let next_values: &[ExtensionTarget<D>; NUM_COLUMNS] =
vars.get_next_values().try_into().unwrap();
let next_values: &BytePackingColumnsView<ExtensionTarget<D>> = next_values.borrow();

// Check the range column: First value must be 0, last row
// must be 255, and intermediate rows must increment by 0
// or 1.
let rc1 = local_values[RANGE_COUNTER];
let rc2 = next_values[RANGE_COUNTER];
let rc1 = local_values.range_counter;
let rc2 = next_values.range_counter;
yield_constr.constraint_first_row(builder, rc1);
let incr = builder.sub_extension(rc2, rc1);
let t = builder.mul_sub_extension(incr, incr, incr);
Expand All @@ -353,7 +379,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt

// We filter active columns by summing all the byte indices.
// Constraining each of them to be boolean is done later on below.
let current_filter = builder.add_many_extension(&local_values[LEN_INDICES_COLS]);
let current_filter = builder.add_many_extension(local_values.index_len);
let constraint = builder.mul_sub_extension(current_filter, current_filter, current_filter);
yield_constr.constraint(builder, constraint);

Expand All @@ -362,20 +388,20 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt
yield_constr.constraint_first_row(builder, constraint);

// The is_read flag must be boolean.
let current_is_read = local_values[IS_READ];
let current_is_read = local_values.is_read;
let constraint =
builder.mul_sub_extension(current_is_read, current_is_read, current_is_read);
yield_constr.constraint(builder, constraint);

// Each byte index must be boolean.
for i in 0..NUM_BYTES {
let idx_i = local_values[index_len(i)];
let idx_i = local_values.index_len[i];
let constraint = builder.mul_sub_extension(idx_i, idx_i, idx_i);
yield_constr.constraint(builder, constraint);
}

// Only padding rows have their filter turned off.
let next_filter = builder.add_many_extension(&next_values[LEN_INDICES_COLS]);
let next_filter = builder.add_many_extension(next_values.index_len);
let constraint = builder.sub_extension(next_filter, current_filter);
let constraint = builder.mul_extension(next_filter, constraint);
yield_constr.constraint_transition(builder, constraint);
Expand All @@ -386,7 +412,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt
// be 0.
for j in i + 1..NUM_BYTES {
let constr =
builder.mul_extension(local_values[index_len(i)], local_values[value_bytes(j)]);
builder.mul_extension(local_values.index_len[i], local_values.value_bytes[j]);
yield_constr.constraint(builder, constr);
}
}
Expand All @@ -398,9 +424,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt

fn lookups(&self) -> Vec<Lookup<F>> {
vec![Lookup {
columns: Column::singles(value_bytes(0)..value_bytes(0) + NUM_BYTES).collect(),
table_column: Column::single(RANGE_COUNTER),
frequencies_column: Column::single(RC_FREQUENCIES),
columns: Column::singles(
BYTE_PACKING_COL_MAP.value_bytes[0]
..BYTE_PACKING_COL_MAP.value_bytes[0] + NUM_BYTES,
)
.collect(),
table_column: Column::single(BYTE_PACKING_COL_MAP.range_counter),
frequencies_column: Column::single(BYTE_PACKING_COL_MAP.rc_frequencies),
filter_columns: vec![Default::default(); NUM_BYTES],
}]
}
Expand Down
Loading

0 comments on commit 2bd9d32

Please sign in to comment.