feat: Implement Columns view for BytePackingStark #422

Merged · 2 commits · Aug 8, 2024
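Note: the `BytePackingColumnsView` struct and the `BYTE_PACKING_COL_MAP` constant used throughout the diff below are defined in `byte_packing/columns.rs`, which is not part of this file's changes. The following is only a minimal sketch of what such a columns view might look like; the field order, the derives, the NUM_BYTES value of 32, and the way the column map is built are assumptions for illustration, not the PR's actual code.

use core::mem::{size_of, transmute};

use super::NUM_BYTES; // assumed to be 32, as in byte_packing/mod.rs

/// Named view over one row of the byte-packing trace (sketch only).
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct BytePackingColumnsView<T: Copy> {
    /// 1 if this is a READ operation, 0 if it is a WRITE.
    pub is_read: T,
    pub addr_context: T,
    pub addr_segment: T,
    pub addr_virtual: T,
    pub timestamp: T,
    /// index_len[i] is 1 iff the byte sequence being packed has length i + 1.
    pub index_len: [T; NUM_BYTES],
    /// The bytes of the packed value.
    pub value_bytes: [T; NUM_BYTES],
    /// Range-check counter and frequency columns.
    pub range_counter: T,
    pub rc_frequencies: T,
}

/// With T = u8 every column slot occupies one byte, so the struct's size
/// equals the number of columns in a flat row.
pub(crate) const NUM_COLUMNS: usize = size_of::<BytePackingColumnsView<u8>>();

/// A view whose fields hold their own flat column index, letting call sites
/// write BYTE_PACKING_COL_MAP.value_bytes[i] instead of the old value_bytes(i).
pub(crate) const BYTE_PACKING_COL_MAP: BytePackingColumnsView<usize> = {
    let mut indices = [0usize; NUM_COLUMNS];
    let mut i = 0;
    while i < NUM_COLUMNS {
        indices[i] = i;
        i += 1;
    }
    // Sound only because the view is #[repr(C)] with exactly NUM_COLUMNS
    // usize-typed column slots.
    unsafe { transmute(indices) }
};

The effect, visible throughout the diff, is that constraint code addresses columns by name (e.g. local_values.value_bytes[j]) while the CTL helpers keep producing flat indices through BYTE_PACKING_COL_MAP.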
132 changes: 81 additions & 51 deletions evm_arithmetization/src/byte_packing/byte_packing_stark.rs
@@ -28,6 +28,7 @@
//! 256^length, and as a result a different value will be stored in memory.

use core::marker::PhantomData;
use std::borrow::Borrow;

use itertools::Itertools;
use plonky2::field::extension::{Extendable, FieldExtension};
@@ -46,10 +47,7 @@ use starky::stark::Stark;

use super::NUM_BYTES;
use crate::all_stark::EvmStarkFrame;
use crate::byte_packing::columns::{
index_len, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, IS_READ, LEN_INDICES_COLS,
NUM_COLUMNS, RANGE_COUNTER, RC_FREQUENCIES, TIMESTAMP,
};
use crate::byte_packing::columns::*;
use crate::witness::memory::MemoryAddress;

/// Strict upper bound for the individual bytes range-check.
@@ -65,7 +63,8 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
// obtain the corresponding limb.
let outputs: Vec<Column<F>> = (0..8)
.map(|i| {
let range = value_bytes(i * 4)..value_bytes(i * 4) + 4;
let range = BYTE_PACKING_COL_MAP.value_bytes[i * 4]
..BYTE_PACKING_COL_MAP.value_bytes[i * 4] + 4;
Column::linear_combination(
range
.enumerate()
@@ -74,34 +73,54 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
})
.collect();

let sequence_len: Column<F> = Column::linear_combination(
(0..NUM_BYTES).map(|i| (index_len(i), F::from_canonical_usize(i + 1))),
);

Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL])
.chain([sequence_len])
.chain(Column::singles(&[TIMESTAMP]))
.chain(outputs)
.collect()
let sequence_len: Column<F> = Column::linear_combination((0..NUM_BYTES).map(|i| {
(
BYTE_PACKING_COL_MAP.index_len[i],
F::from_canonical_usize(i + 1),
)
}));

Column::singles([
BYTE_PACKING_COL_MAP.is_read,
BYTE_PACKING_COL_MAP.addr_context,
BYTE_PACKING_COL_MAP.addr_segment,
BYTE_PACKING_COL_MAP.addr_virtual,
])
.chain([sequence_len])
.chain(Column::singles(&[BYTE_PACKING_COL_MAP.timestamp]))
.chain(outputs)
.collect()
}

/// CTL filter for the `BytePackingStark` looked table.
pub(crate) fn ctl_looked_filter<F: Field>() -> Filter<F> {
// The CPU table is only interested in our sequence end rows,
// since those contain the final limbs of our packed int.
Filter::new_simple(Column::sum((0..NUM_BYTES).map(index_len)))
Filter::new_simple(Column::sum(
(0..NUM_BYTES).map(|i| BYTE_PACKING_COL_MAP.index_len[i]),
))
}

/// Column linear combination for the `BytePackingStark` table reading/writing
/// the `i`th byte sequence from `MemoryStark`.
pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
let mut res = Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT]).collect_vec();
let mut res = Column::singles([
BYTE_PACKING_COL_MAP.is_read,
BYTE_PACKING_COL_MAP.addr_context,
BYTE_PACKING_COL_MAP.addr_segment,
])
.collect_vec();

// Compute the virtual address: `ADDR_VIRTUAL` + `sequence_len` - 1 - i.
let sequence_len_minus_one = (0..NUM_BYTES)
.map(|j| (index_len(j), F::from_canonical_usize(j)))
.map(|j| {
(
BYTE_PACKING_COL_MAP.index_len[j],
F::from_canonical_usize(j),
)
})
.collect::<Vec<_>>();
let mut addr_virt_cols = vec![(ADDR_VIRTUAL, F::ONE)];
let mut addr_virt_cols = vec![(BYTE_PACKING_COL_MAP.addr_virtual, F::ONE)];
addr_virt_cols.extend(sequence_len_minus_one);
let addr_virt = Column::linear_combination_with_constant(
addr_virt_cols,
@@ -111,20 +130,22 @@ pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
res.push(addr_virt);

// The i'th input byte being read/written.
res.push(Column::single(value_bytes(i)));
res.push(Column::single(BYTE_PACKING_COL_MAP.value_bytes[i]));

// Since we're reading a single byte, the higher limbs must be zero.
res.extend((1..8).map(|_| Column::zero()));

res.push(Column::single(TIMESTAMP));
res.push(Column::single(BYTE_PACKING_COL_MAP.timestamp));

res
}

/// CTL filter for reading/writing the `i`th byte of the byte sequence from/to
/// memory.
pub(crate) fn ctl_looking_memory_filter<F: Field>(i: usize) -> Filter<F> {
Filter::new_simple(Column::sum((i..NUM_BYTES).map(index_len)))
Filter::new_simple(Column::sum(
(i..NUM_BYTES).map(|i| BYTE_PACKING_COL_MAP.index_len[i]),
))
}

/// Information about a byte packing operation needed for witness generation.
@@ -205,24 +226,24 @@ impl<F: RichField + Extendable<D>, const D: usize> BytePackingStark<F, D> {
virt,
} = base_address;

let mut row = [F::ZERO; NUM_COLUMNS];
row[IS_READ] = F::from_bool(is_read);
let mut row = BytePackingColumnsView::default();
row.is_read = F::from_bool(is_read);

row[ADDR_CONTEXT] = F::from_canonical_usize(context);
row[ADDR_SEGMENT] = F::from_canonical_usize(segment);
row.addr_context = F::from_canonical_usize(context);
row.addr_segment = F::from_canonical_usize(segment);
// We store the initial virtual address. But for the CTLs,
// we start with virt + sequence_len - 1.
row[ADDR_VIRTUAL] = F::from_canonical_usize(virt);
row.addr_virtual = F::from_canonical_usize(virt);

row[TIMESTAMP] = F::from_canonical_usize(timestamp);
row.timestamp = F::from_canonical_usize(timestamp);

row[index_len(bytes.len() - 1)] = F::ONE;
row.index_len[bytes.len() - 1] = F::ONE;

for (i, &byte) in bytes.iter().rev().enumerate() {
row[value_bytes(i)] = F::from_canonical_u8(byte);
row.value_bytes[i] = F::from_canonical_u8(byte);
}

row
row.into()
}

const fn generate_padding_row(&self) -> [F; NUM_COLUMNS] {
@@ -237,26 +258,27 @@ impl<F: RichField + Extendable<D>, const D: usize> BytePackingStark<F, D> {
debug_assert!(cols.iter().all(|col| col.len() == n_rows));

for i in 0..BYTE_RANGE_MAX {
cols[RANGE_COUNTER][i] = F::from_canonical_usize(i);
cols[BYTE_PACKING_COL_MAP.range_counter][i] = F::from_canonical_usize(i);
}
for i in BYTE_RANGE_MAX..n_rows {
cols[RANGE_COUNTER][i] = F::from_canonical_usize(BYTE_RANGE_MAX - 1);
cols[BYTE_PACKING_COL_MAP.range_counter][i] =
F::from_canonical_usize(BYTE_RANGE_MAX - 1);
}

// For each column c in cols, generate the range-check
// permutations and put them in the corresponding range-check
// columns rc_c and rc_c+1.
for col in 0..NUM_BYTES {
for i in 0..n_rows {
let c = value_bytes(col);
let c = BYTE_PACKING_COL_MAP.value_bytes[col];
let x = cols[c][i].to_canonical_u64() as usize;
assert!(
x < BYTE_RANGE_MAX,
"column value {} exceeds the max range value {}",
x,
BYTE_RANGE_MAX
);
cols[RC_FREQUENCIES][x] += F::ONE;
cols[BYTE_PACKING_COL_MAP.rc_frequencies][x] += F::ONE;
}
}
}
@@ -279,13 +301,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingStark<F, D>
P: PackedField<Scalar = FE>,
{
let local_values: &[P; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap();
let local_values: &BytePackingColumnsView<P> = local_values.borrow();
let next_values: &[P; NUM_COLUMNS] = vars.get_next_values().try_into().unwrap();
let next_values: &BytePackingColumnsView<P> = next_values.borrow();

// Check the range column: First value must be 0, last row
// must be 255, and intermediate rows must increment by 0
// or 1.
let rc1 = local_values[RANGE_COUNTER];
let rc2 = next_values[RANGE_COUNTER];
let rc1 = local_values.range_counter;
let rc2 = next_values.range_counter;
yield_constr.constraint_first_row(rc1);
let incr = rc2 - rc1;
yield_constr.constraint_transition(incr * incr - incr);
@@ -296,32 +320,32 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingStark<F, D>

// We filter active columns by summing all the byte indices.
// Constraining each of them to be boolean is done later on below.
let current_filter = local_values[LEN_INDICES_COLS].iter().copied().sum::<P>();
let current_filter = local_values.index_len.iter().copied().sum::<P>();
yield_constr.constraint(current_filter * (current_filter - one));

// The filter column must start with one.
yield_constr.constraint_first_row(current_filter - one);

// The is_read flag must be boolean.
let current_is_read = local_values[IS_READ];
let current_is_read = local_values.is_read;
yield_constr.constraint(current_is_read * (current_is_read - one));

// Each byte index must be boolean.
for i in 0..NUM_BYTES {
let idx_i = local_values[index_len(i)];
let idx_i = local_values.index_len[i];
yield_constr.constraint(idx_i * (idx_i - one));
}

// Only padding rows have their filter turned off.
let next_filter = next_values[LEN_INDICES_COLS].iter().copied().sum::<P>();
let next_filter = next_values.index_len.iter().copied().sum::<P>();
yield_constr.constraint_transition(next_filter * (next_filter - current_filter));

// Check that all limbs after final length are 0.
for i in 0..NUM_BYTES - 1 {
// If the length is i+1, then value_bytes(i+1),...,value_bytes(NUM_BYTES-1) must
// be 0.
for j in i + 1..NUM_BYTES {
yield_constr.constraint(local_values[index_len(i)] * local_values[value_bytes(j)]);
yield_constr.constraint(local_values.index_len[i] * local_values.value_bytes[j]);
}
}
}
Expand All @@ -334,14 +358,16 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingSt
) {
let local_values: &[ExtensionTarget<D>; NUM_COLUMNS] =
vars.get_local_values().try_into().unwrap();
let local_values: &BytePackingColumnsView<ExtensionTarget<D>> = local_values.borrow();
let next_values: &[ExtensionTarget<D>; NUM_COLUMNS] =
vars.get_next_values().try_into().unwrap();
let next_values: &BytePackingColumnsView<ExtensionTarget<D>> = next_values.borrow();

// Check the range column: First value must be 0, last row
// must be 255, and intermediate rows must increment by 0
// or 1.
let rc1 = local_values[RANGE_COUNTER];
let rc2 = next_values[RANGE_COUNTER];
let rc1 = local_values.range_counter;
let rc2 = next_values.range_counter;
yield_constr.constraint_first_row(builder, rc1);
let incr = builder.sub_extension(rc2, rc1);
let t = builder.mul_sub_extension(incr, incr, incr);
@@ -353,7 +379,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingStark<F, D>

// We filter active columns by summing all the byte indices.
// Constraining each of them to be boolean is done later on below.
let current_filter = builder.add_many_extension(&local_values[LEN_INDICES_COLS]);
let current_filter = builder.add_many_extension(local_values.index_len);
let constraint = builder.mul_sub_extension(current_filter, current_filter, current_filter);
yield_constr.constraint(builder, constraint);

@@ -362,20 +388,20 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingStark<F, D>
yield_constr.constraint_first_row(builder, constraint);

// The is_read flag must be boolean.
let current_is_read = local_values[IS_READ];
let current_is_read = local_values.is_read;
let constraint =
builder.mul_sub_extension(current_is_read, current_is_read, current_is_read);
yield_constr.constraint(builder, constraint);

// Each byte index must be boolean.
for i in 0..NUM_BYTES {
let idx_i = local_values[index_len(i)];
let idx_i = local_values.index_len[i];
let constraint = builder.mul_sub_extension(idx_i, idx_i, idx_i);
yield_constr.constraint(builder, constraint);
}

// Only padding rows have their filter turned off.
let next_filter = builder.add_many_extension(&next_values[LEN_INDICES_COLS]);
let next_filter = builder.add_many_extension(next_values.index_len);
let constraint = builder.sub_extension(next_filter, current_filter);
let constraint = builder.mul_extension(next_filter, constraint);
yield_constr.constraint_transition(builder, constraint);
@@ -386,7 +412,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingStark<F, D>
// be 0.
for j in i + 1..NUM_BYTES {
let constr =
builder.mul_extension(local_values[index_len(i)], local_values[value_bytes(j)]);
builder.mul_extension(local_values.index_len[i], local_values.value_bytes[j]);
yield_constr.constraint(builder, constr);
}
}
@@ -398,9 +424,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for BytePackingStark<F, D>

fn lookups(&self) -> Vec<Lookup<F>> {
vec![Lookup {
columns: Column::singles(value_bytes(0)..value_bytes(0) + NUM_BYTES).collect(),
table_column: Column::single(RANGE_COUNTER),
frequencies_column: Column::single(RC_FREQUENCIES),
columns: Column::singles(
BYTE_PACKING_COL_MAP.value_bytes[0]
..BYTE_PACKING_COL_MAP.value_bytes[0] + NUM_BYTES,
)
.collect(),
table_column: Column::single(BYTE_PACKING_COL_MAP.range_counter),
frequencies_column: Column::single(BYTE_PACKING_COL_MAP.rc_frequencies),
filter_columns: vec![Default::default(); NUM_BYTES],
}]
}
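The diff also relies on zero-copy conversions between a flat row `[T; NUM_COLUMNS]` and the named view: `local_values.borrow()` in the constraint functions and `row.into()` at the end of `generate_row`. Those impls live in `columns.rs` as well; below is a sketch of one plausible way to write them, under the same assumptions as the struct sketch above (the PR may generate them with a macro instead).

use core::borrow::Borrow;
use core::mem::{size_of, transmute, transmute_copy};

// Reinterpret a flat row as the named view without copying. Sound only
// because BytePackingColumnsView is #[repr(C)] and holds exactly NUM_COLUMNS
// column slots of type T.
impl<T: Copy> Borrow<BytePackingColumnsView<T>> for [T; NUM_COLUMNS] {
    fn borrow(&self) -> &BytePackingColumnsView<T> {
        unsafe { transmute(self) }
    }
}

// Used by `row.into()` when generate_row hands the finished view back as a
// flat row for the trace.
impl<T: Copy> From<BytePackingColumnsView<T>> for [T; NUM_COLUMNS] {
    fn from(view: BytePackingColumnsView<T>) -> Self {
        debug_assert_eq!(size_of::<BytePackingColumnsView<T>>(), size_of::<Self>());
        // transmute_copy sidesteps the compile-time size check that a plain
        // transmute cannot perform for a generic T.
        unsafe { transmute_copy(&view) }
    }
}

With impls of this shape in place, `vars.get_local_values().try_into().unwrap()` first yields a `&[P; NUM_COLUMNS]`, and the subsequent `.borrow()` reinterprets it as `&BytePackingColumnsView<P>` at no runtime cost, which is what enables the field-access style used in the new constraint code.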