From b05185fd0bf60ad6f34e3313dd9aeb6c5a5638c1 Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Sat, 20 Jul 2024 16:48:36 -0400 Subject: [PATCH] feat: Implement Columns view for BytePackingStark --- .../src/byte_packing/byte_packing_stark.rs | 132 +++++++++++------- .../src/byte_packing/columns.rs | 79 ++++++----- 2 files changed, 123 insertions(+), 88 deletions(-) diff --git a/evm_arithmetization/src/byte_packing/byte_packing_stark.rs b/evm_arithmetization/src/byte_packing/byte_packing_stark.rs index e37195497..018c23e14 100644 --- a/evm_arithmetization/src/byte_packing/byte_packing_stark.rs +++ b/evm_arithmetization/src/byte_packing/byte_packing_stark.rs @@ -28,6 +28,7 @@ //! 256^length, and as a result a different value will be stored in memory. use core::marker::PhantomData; +use std::borrow::Borrow; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; @@ -46,10 +47,7 @@ use starky::stark::Stark; use super::NUM_BYTES; use crate::all_stark::EvmStarkFrame; -use crate::byte_packing::columns::{ - index_len, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, IS_READ, LEN_INDICES_COLS, - NUM_COLUMNS, RANGE_COUNTER, RC_FREQUENCIES, TIMESTAMP, -}; +use crate::byte_packing::columns::*; use crate::witness::memory::MemoryAddress; /// Strict upper bound for the individual bytes range-check. @@ -65,7 +63,8 @@ pub(crate) fn ctl_looked_data() -> Vec> { // obtain the corresponding limb. let outputs: Vec> = (0..8) .map(|i| { - let range = value_bytes(i * 4)..value_bytes(i * 4) + 4; + let range = BYTE_PACKING_COL_MAP.value_bytes[i * 4] + ..BYTE_PACKING_COL_MAP.value_bytes[i * 4] + 4; Column::linear_combination( range .enumerate() @@ -74,34 +73,54 @@ pub(crate) fn ctl_looked_data() -> Vec> { }) .collect(); - let sequence_len: Column = Column::linear_combination( - (0..NUM_BYTES).map(|i| (index_len(i), F::from_canonical_usize(i + 1))), - ); - - Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]) - .chain([sequence_len]) - .chain(Column::singles(&[TIMESTAMP])) - .chain(outputs) - .collect() + let sequence_len: Column = Column::linear_combination((0..NUM_BYTES).map(|i| { + ( + BYTE_PACKING_COL_MAP.index_len[i], + F::from_canonical_usize(i + 1), + ) + })); + + Column::singles([ + BYTE_PACKING_COL_MAP.is_read, + BYTE_PACKING_COL_MAP.addr_context, + BYTE_PACKING_COL_MAP.addr_segment, + BYTE_PACKING_COL_MAP.addr_virtual, + ]) + .chain([sequence_len]) + .chain(Column::singles(&[BYTE_PACKING_COL_MAP.timestamp])) + .chain(outputs) + .collect() } /// CTL filter for the `BytePackingStark` looked table. pub(crate) fn ctl_looked_filter() -> Filter { // The CPU table is only interested in our sequence end rows, // since those contain the final limbs of our packed int. - Filter::new_simple(Column::sum((0..NUM_BYTES).map(index_len))) + Filter::new_simple(Column::sum( + (0..NUM_BYTES).map(|i| BYTE_PACKING_COL_MAP.index_len[i]), + )) } /// Column linear combination for the `BytePackingStark` table reading/writing /// the `i`th byte sequence from `MemoryStark`. pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { - let mut res = Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT]).collect_vec(); + let mut res = Column::singles([ + BYTE_PACKING_COL_MAP.is_read, + BYTE_PACKING_COL_MAP.addr_context, + BYTE_PACKING_COL_MAP.addr_segment, + ]) + .collect_vec(); // Compute the virtual address: `ADDR_VIRTUAL` + `sequence_len` - 1 - i. let sequence_len_minus_one = (0..NUM_BYTES) - .map(|j| (index_len(j), F::from_canonical_usize(j))) + .map(|j| { + ( + BYTE_PACKING_COL_MAP.index_len[j], + F::from_canonical_usize(j), + ) + }) .collect::>(); - let mut addr_virt_cols = vec![(ADDR_VIRTUAL, F::ONE)]; + let mut addr_virt_cols = vec![(BYTE_PACKING_COL_MAP.addr_virtual, F::ONE)]; addr_virt_cols.extend(sequence_len_minus_one); let addr_virt = Column::linear_combination_with_constant( addr_virt_cols, @@ -111,12 +130,12 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { res.push(addr_virt); // The i'th input byte being read/written. - res.push(Column::single(value_bytes(i))); + res.push(Column::single(BYTE_PACKING_COL_MAP.value_bytes[i])); // Since we're reading a single byte, the higher limbs must be zero. res.extend((1..8).map(|_| Column::zero())); - res.push(Column::single(TIMESTAMP)); + res.push(Column::single(BYTE_PACKING_COL_MAP.timestamp)); res } @@ -124,7 +143,9 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { /// CTL filter for reading/writing the `i`th byte of the byte sequence from/to /// memory. pub(crate) fn ctl_looking_memory_filter(i: usize) -> Filter { - Filter::new_simple(Column::sum((i..NUM_BYTES).map(index_len))) + Filter::new_simple(Column::sum( + (i..NUM_BYTES).map(|i| BYTE_PACKING_COL_MAP.index_len[i]), + )) } /// Information about a byte packing operation needed for witness generation. @@ -205,24 +226,24 @@ impl, const D: usize> BytePackingStark { virt, } = base_address; - let mut row = [F::ZERO; NUM_COLUMNS]; - row[IS_READ] = F::from_bool(is_read); + let mut row = BytePackingColumnsView::default(); + row.is_read = F::from_bool(is_read); - row[ADDR_CONTEXT] = F::from_canonical_usize(context); - row[ADDR_SEGMENT] = F::from_canonical_usize(segment); + row.addr_context = F::from_canonical_usize(context); + row.addr_segment = F::from_canonical_usize(segment); // We store the initial virtual segment. But the CTLs, // we start with virt + sequence_len - 1. - row[ADDR_VIRTUAL] = F::from_canonical_usize(virt); + row.addr_virtual = F::from_canonical_usize(virt); - row[TIMESTAMP] = F::from_canonical_usize(timestamp); + row.timestamp = F::from_canonical_usize(timestamp); - row[index_len(bytes.len() - 1)] = F::ONE; + row.index_len[bytes.len() - 1] = F::ONE; for (i, &byte) in bytes.iter().rev().enumerate() { - row[value_bytes(i)] = F::from_canonical_u8(byte); + row.value_bytes[i] = F::from_canonical_u8(byte); } - row + row.into() } const fn generate_padding_row(&self) -> [F; NUM_COLUMNS] { @@ -237,10 +258,11 @@ impl, const D: usize> BytePackingStark { debug_assert!(cols.iter().all(|col| col.len() == n_rows)); for i in 0..BYTE_RANGE_MAX { - cols[RANGE_COUNTER][i] = F::from_canonical_usize(i); + cols[BYTE_PACKING_COL_MAP.range_counter][i] = F::from_canonical_usize(i); } for i in BYTE_RANGE_MAX..n_rows { - cols[RANGE_COUNTER][i] = F::from_canonical_usize(BYTE_RANGE_MAX - 1); + cols[BYTE_PACKING_COL_MAP.range_counter][i] = + F::from_canonical_usize(BYTE_RANGE_MAX - 1); } // For each column c in cols, generate the range-check @@ -248,7 +270,7 @@ impl, const D: usize> BytePackingStark { // columns rc_c and rc_c+1. for col in 0..NUM_BYTES { for i in 0..n_rows { - let c = value_bytes(col); + let c = BYTE_PACKING_COL_MAP.value_bytes[col]; let x = cols[c][i].to_canonical_u64() as usize; assert!( x < BYTE_RANGE_MAX, @@ -256,7 +278,7 @@ impl, const D: usize> BytePackingStark { x, BYTE_RANGE_MAX ); - cols[RC_FREQUENCIES][x] += F::ONE; + cols[BYTE_PACKING_COL_MAP.rc_frequencies][x] += F::ONE; } } } @@ -279,13 +301,15 @@ impl, const D: usize> Stark for BytePackingSt P: PackedField, { let local_values: &[P; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap(); + let local_values: &BytePackingColumnsView

= local_values.borrow(); let next_values: &[P; NUM_COLUMNS] = vars.get_next_values().try_into().unwrap(); + let next_values: &BytePackingColumnsView

= next_values.borrow(); // Check the range column: First value must be 0, last row // must be 255, and intermediate rows must increment by 0 // or 1. - let rc1 = local_values[RANGE_COUNTER]; - let rc2 = next_values[RANGE_COUNTER]; + let rc1 = local_values.range_counter; + let rc2 = next_values.range_counter; yield_constr.constraint_first_row(rc1); let incr = rc2 - rc1; yield_constr.constraint_transition(incr * incr - incr); @@ -296,24 +320,24 @@ impl, const D: usize> Stark for BytePackingSt // We filter active columns by summing all the byte indices. // Constraining each of them to be boolean is done later on below. - let current_filter = local_values[LEN_INDICES_COLS].iter().copied().sum::

(); + let current_filter = local_values.index_len.iter().copied().sum::

(); yield_constr.constraint(current_filter * (current_filter - one)); // The filter column must start by one. yield_constr.constraint_first_row(current_filter - one); // The is_read flag must be boolean. - let current_is_read = local_values[IS_READ]; + let current_is_read = local_values.is_read; yield_constr.constraint(current_is_read * (current_is_read - one)); // Each byte index must be boolean. for i in 0..NUM_BYTES { - let idx_i = local_values[index_len(i)]; + let idx_i = local_values.index_len[i]; yield_constr.constraint(idx_i * (idx_i - one)); } // Only padding rows have their filter turned off. - let next_filter = next_values[LEN_INDICES_COLS].iter().copied().sum::

(); + let next_filter = next_values.index_len.iter().copied().sum::

(); yield_constr.constraint_transition(next_filter * (next_filter - current_filter)); // Check that all limbs after final length are 0. @@ -321,7 +345,7 @@ impl, const D: usize> Stark for BytePackingSt // If the length is i+1, then value_bytes(i+1),...,value_bytes(NUM_BYTES-1) must // be 0. for j in i + 1..NUM_BYTES { - yield_constr.constraint(local_values[index_len(i)] * local_values[value_bytes(j)]); + yield_constr.constraint(local_values.index_len[i] * local_values.value_bytes[j]); } } } @@ -334,14 +358,16 @@ impl, const D: usize> Stark for BytePackingSt ) { let local_values: &[ExtensionTarget; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap(); + let local_values: &BytePackingColumnsView> = local_values.borrow(); let next_values: &[ExtensionTarget; NUM_COLUMNS] = vars.get_next_values().try_into().unwrap(); + let next_values: &BytePackingColumnsView> = next_values.borrow(); // Check the range column: First value must be 0, last row // must be 255, and intermediate rows must increment by 0 // or 1. - let rc1 = local_values[RANGE_COUNTER]; - let rc2 = next_values[RANGE_COUNTER]; + let rc1 = local_values.range_counter; + let rc2 = next_values.range_counter; yield_constr.constraint_first_row(builder, rc1); let incr = builder.sub_extension(rc2, rc1); let t = builder.mul_sub_extension(incr, incr, incr); @@ -353,7 +379,7 @@ impl, const D: usize> Stark for BytePackingSt // We filter active columns by summing all the byte indices. // Constraining each of them to be boolean is done later on below. - let current_filter = builder.add_many_extension(&local_values[LEN_INDICES_COLS]); + let current_filter = builder.add_many_extension(local_values.index_len); let constraint = builder.mul_sub_extension(current_filter, current_filter, current_filter); yield_constr.constraint(builder, constraint); @@ -362,20 +388,20 @@ impl, const D: usize> Stark for BytePackingSt yield_constr.constraint_first_row(builder, constraint); // The is_read flag must be boolean. - let current_is_read = local_values[IS_READ]; + let current_is_read = local_values.is_read; let constraint = builder.mul_sub_extension(current_is_read, current_is_read, current_is_read); yield_constr.constraint(builder, constraint); // Each byte index must be boolean. for i in 0..NUM_BYTES { - let idx_i = local_values[index_len(i)]; + let idx_i = local_values.index_len[i]; let constraint = builder.mul_sub_extension(idx_i, idx_i, idx_i); yield_constr.constraint(builder, constraint); } // Only padding rows have their filter turned off. - let next_filter = builder.add_many_extension(&next_values[LEN_INDICES_COLS]); + let next_filter = builder.add_many_extension(next_values.index_len); let constraint = builder.sub_extension(next_filter, current_filter); let constraint = builder.mul_extension(next_filter, constraint); yield_constr.constraint_transition(builder, constraint); @@ -386,7 +412,7 @@ impl, const D: usize> Stark for BytePackingSt // be 0. for j in i + 1..NUM_BYTES { let constr = - builder.mul_extension(local_values[index_len(i)], local_values[value_bytes(j)]); + builder.mul_extension(local_values.index_len[i], local_values.value_bytes[j]); yield_constr.constraint(builder, constr); } } @@ -398,9 +424,13 @@ impl, const D: usize> Stark for BytePackingSt fn lookups(&self) -> Vec> { vec![Lookup { - columns: Column::singles(value_bytes(0)..value_bytes(0) + NUM_BYTES).collect(), - table_column: Column::single(RANGE_COUNTER), - frequencies_column: Column::single(RC_FREQUENCIES), + columns: Column::singles( + BYTE_PACKING_COL_MAP.value_bytes[0] + ..BYTE_PACKING_COL_MAP.value_bytes[0] + NUM_BYTES, + ) + .collect(), + table_column: Column::single(BYTE_PACKING_COL_MAP.range_counter), + frequencies_column: Column::single(BYTE_PACKING_COL_MAP.rc_frequencies), filter_columns: vec![Default::default(); NUM_BYTES], }] } diff --git a/evm_arithmetization/src/byte_packing/columns.rs b/evm_arithmetization/src/byte_packing/columns.rs index 45e3211de..5e394ae71 100644 --- a/evm_arithmetization/src/byte_packing/columns.rs +++ b/evm_arithmetization/src/byte_packing/columns.rs @@ -1,43 +1,48 @@ //! Byte packing registers. -use core::ops::Range; - -use crate::byte_packing::NUM_BYTES; - -/// 1 if this is a READ operation, and 0 if this is a WRITE operation. -pub(crate) const IS_READ: usize = 0; - -pub(super) const LEN_INDICES_START: usize = IS_READ + 1; -// There are `NUM_BYTES` columns used to represent the length of -// the input byte sequence for a (un)packing operation. -// index_len(i) is 1 iff the length is i+1. -pub(crate) const fn index_len(i: usize) -> usize { - debug_assert!(i < NUM_BYTES); - LEN_INDICES_START + i +use std::mem::transmute; + +use zk_evm_proc_macro::Columns; + +use crate::{byte_packing::NUM_BYTES, util::indices_arr}; + +/// A view of `BytePackingStark`'s columns. +#[repr(C)] +#[derive(Columns, Eq, PartialEq, Debug)] +pub(crate) struct BytePackingColumnsView { + /// 1 if this is a READ operation, and 0 if this is a WRITE operation. + pub is_read: T, + + // There are `NUM_BYTES` columns used to represent the length of + // the input byte sequence for a (un)packing operation. + // index_len[i] is 1 iff the length is i+1. + pub index_len: [T; NUM_BYTES], + + pub addr_context: T, + pub addr_segment: T, + pub addr_virtual: T, + pub timestamp: T, + + // 32 byte limbs hold a total of 256 bits. + // There are `NUM_BYTES` columns used to store the values of the bytes + // that are being read/written for an (un)packing operation. + pub value_bytes: [T; NUM_BYTES], + + /// The counter column (used for the logUp range check) starts from 0 and + /// increments. + pub range_counter: T, + /// The frequencies column used in logUp. + pub rc_frequencies: T, } -// Note: Those are used to obtain the length of a sequence of bytes being -// processed. -pub(crate) const LEN_INDICES_COLS: Range = LEN_INDICES_START..LEN_INDICES_START + NUM_BYTES; - -pub(crate) const ADDR_CONTEXT: usize = LEN_INDICES_START + NUM_BYTES; -pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; -pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; -pub(crate) const TIMESTAMP: usize = ADDR_VIRTUAL + 1; - -// 32 byte limbs hold a total of 256 bits. -const BYTES_VALUES_START: usize = TIMESTAMP + 1; -// There are `NUM_BYTES` columns used to store the values of the bytes -// that are being read/written for an (un)packing operation. -pub(crate) const fn value_bytes(i: usize) -> usize { - debug_assert!(i < NUM_BYTES); - BYTES_VALUES_START + i -} +// `u8` is guaranteed to have a `size_of` of 1. +/// Number of columns in `BytePackingStark`. +pub(crate) const NUM_COLUMNS: usize = size_of::>(); -/// The counter column (used for the range check) starts from 0 and increments. -pub(crate) const RANGE_COUNTER: usize = BYTES_VALUES_START + NUM_BYTES; -/// The frequencies column used in logUp. -pub(crate) const RC_FREQUENCIES: usize = RANGE_COUNTER + 1; +const fn make_col_map() -> BytePackingColumnsView { + let indices_arr = indices_arr::(); + unsafe { transmute::<[usize; NUM_COLUMNS], BytePackingColumnsView>(indices_arr) } +} -/// Number of columns in `BytePackingStark`. -pub(crate) const NUM_COLUMNS: usize = RANGE_COUNTER + 2; +/// Map between the `BytePacking` columns and (0..`NUM_COLUMNS`) +pub(crate) const BYTE_PACKING_COL_MAP: BytePackingColumnsView = make_col_map();