Skip to content

Commit

Permalink
Add LogicColumnsView struct for LogicStark (#347)
Browse files Browse the repository at this point in the history
* Add LogicColumnsView struct

* Add generate-evaluate test for logic stark

* Add comments to gen-eval test in logic stark
  • Loading branch information
gio256 authored Jul 2, 2024
1 parent 7a4a9bb commit a591b6f
Showing 1 changed file with 151 additions and 83 deletions.
234 changes: 151 additions & 83 deletions evm_arithmetization/src/logic.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::borrow::Borrow;
use core::marker::PhantomData;

use ethereum_types::U256;
Expand All @@ -18,7 +19,7 @@ use starky::stark::Stark;
use starky::util::trace_rows_to_poly_values;

use crate::all_stark::EvmStarkFrame;
use crate::logic::columns::NUM_COLUMNS;
use crate::logic::columns::{LogicColumnsView, LOGIC_COL_MAP, NUM_COLUMNS};
use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive};

/// Total number of bits per input/output.
Expand All @@ -32,37 +33,50 @@ const PACKED_LEN: usize = ceil_div_usize(VAL_BITS, PACKED_LIMB_BITS);

/// `LogicStark` columns.
pub(crate) mod columns {
use core::cmp::min;
use core::ops::Range;

use super::{PACKED_LEN, PACKED_LIMB_BITS, VAL_BITS};

/// 1 if this is an AND operation, 0 otherwise.
pub(crate) const IS_AND: usize = 0;
/// 1 if this is an OR operation, 0 otherwise.
pub(crate) const IS_OR: usize = IS_AND + 1;
/// 1 if this is a XOR operation, 0 otherwise.
pub(crate) const IS_XOR: usize = IS_OR + 1;
/// First input, decomposed into bits.
pub(crate) const INPUT0: Range<usize> = (IS_XOR + 1)..(IS_XOR + 1) + VAL_BITS;
/// Second input, decomposed into bits.
pub(crate) const INPUT1: Range<usize> = INPUT0.end..INPUT0.end + VAL_BITS;
/// The result is packed in limbs of `PACKED_LIMB_BITS` bits.
pub(crate) const RESULT: Range<usize> = INPUT1.end..INPUT1.end + PACKED_LEN;

/// Returns the column range for each 32 bit chunk in the input.
pub(crate) fn limb_bit_cols_for_input(
input_bits: Range<usize>,
) -> impl Iterator<Item = Range<usize>> {
(0..PACKED_LEN).map(move |i| {
let start = input_bits.start + i * PACKED_LIMB_BITS;
let end = min(start + PACKED_LIMB_BITS, input_bits.end);
start..end
})
use core::mem::transmute;

use zk_evm_proc_macro::{Columns, DerefColumns};

use super::{PACKED_LEN, VAL_BITS};
use crate::util::indices_arr;

/// Flag columns for the operation to perform.
#[repr(C)]
#[derive(DerefColumns, Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct OpsColumnsView<T> {
/// 1 if this is an AND operation, 0 otherwise.
pub is_and: T,
/// 1 if this is an OR operation, 0 otherwise.
pub is_or: T,
/// 1 if this is a XOR operation, 0 otherwise.
pub is_xor: T,
}

/// Number of columns in `LogicStark`.
pub(crate) const NUM_COLUMNS: usize = RESULT.end;
/// Columns for the `LogicStark`.
#[repr(C)]
#[derive(Columns, Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct LogicColumnsView<T> {
/// The operation to perform.
pub op: OpsColumnsView<T>,
/// First input, decomposed into bits.
pub input0: [T; VAL_BITS],
/// Second input, decomposed into bits.
pub input1: [T; VAL_BITS],
/// The result is packed in limbs of `PACKED_LIMB_BITS` bits.
pub result: [T; PACKED_LEN],
}

/// Total number of columns in `LogicStark`.
/// `u8` is guaranteed to have a `size_of` of 1.
pub(crate) const NUM_COLUMNS: usize = core::mem::size_of::<LogicColumnsView<u8>>();

/// Mapping between [0..NUM_COLUMNS-1] and the logic columns.
pub(crate) const LOGIC_COL_MAP: LogicColumnsView<usize> = make_col_map();

const fn make_col_map() -> LogicColumnsView<usize> {
let indices_arr = indices_arr::<NUM_COLUMNS>();
unsafe { transmute::<[usize; NUM_COLUMNS], LogicColumnsView<usize>>(indices_arr) }
}
}

/// Creates the vector of `Columns` corresponding to the opcode, the two inputs
Expand All @@ -73,23 +87,29 @@ pub(crate) fn ctl_data<F: Field>() -> Vec<Column<F>> {
// will enforce that the reconstructed opcode value from the
// opcode bits matches.
let mut res = vec![Column::linear_combination([
(columns::IS_AND, F::from_canonical_u8(0x16)),
(columns::IS_OR, F::from_canonical_u8(0x17)),
(columns::IS_XOR, F::from_canonical_u8(0x18)),
(LOGIC_COL_MAP.op.is_and, F::from_canonical_u8(0x16)),
(LOGIC_COL_MAP.op.is_or, F::from_canonical_u8(0x17)),
(LOGIC_COL_MAP.op.is_xor, F::from_canonical_u8(0x18)),
])];
res.extend(columns::limb_bit_cols_for_input(columns::INPUT0).map(Column::le_bits));
res.extend(columns::limb_bit_cols_for_input(columns::INPUT1).map(Column::le_bits));
res.extend(columns::RESULT.map(Column::single));
res.extend(
LOGIC_COL_MAP
.input0
.chunks(PACKED_LIMB_BITS)
.map(Column::le_bits),
);
res.extend(
LOGIC_COL_MAP
.input1
.chunks(PACKED_LIMB_BITS)
.map(Column::le_bits),
);
res.extend(LOGIC_COL_MAP.result.map(Column::single));
res
}

/// CTL filter for logic operations.
pub(crate) fn ctl_filter<F: Field>() -> Filter<F> {
Filter::new_simple(Column::sum([
columns::IS_AND,
columns::IS_OR,
columns::IS_XOR,
]))
Filter::new_simple(Column::sum(*LOGIC_COL_MAP.op))
}

/// Structure representing the Logic STARK, which computes all logic operations.
Expand Down Expand Up @@ -151,18 +171,18 @@ impl Operation {
} = self;
let mut row = [F::ZERO; NUM_COLUMNS];
row[match operator {
Op::And => columns::IS_AND,
Op::Or => columns::IS_OR,
Op::Xor => columns::IS_XOR,
Op::And => LOGIC_COL_MAP.op.is_and,
Op::Or => LOGIC_COL_MAP.op.is_or,
Op::Xor => LOGIC_COL_MAP.op.is_xor,
}] = F::ONE;
for i in 0..256 {
row[columns::INPUT0.start + i] = F::from_bool(input0.bit(i));
row[columns::INPUT1.start + i] = F::from_bool(input1.bit(i));
row[LOGIC_COL_MAP.input0[i]] = F::from_bool(input0.bit(i));
row[LOGIC_COL_MAP.input1[i]] = F::from_bool(input1.bit(i));
}
let result_limbs: &[u64] = result.as_ref();
for (i, &limb) in result_limbs.iter().enumerate() {
row[columns::RESULT.start + 2 * i] = F::from_canonical_u32(limb as u32);
row[columns::RESULT.start + 2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32);
row[LOGIC_COL_MAP.result[2 * i]] = F::from_canonical_u32(limb as u32);
row[LOGIC_COL_MAP.result[2 * i + 1]] = F::from_canonical_u32((limb >> 32) as u32);
}
row
}
Expand Down Expand Up @@ -233,11 +253,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
let lv = vars.get_local_values();
let lv: &[P; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap();
let lv: &LogicColumnsView<P> = lv.borrow();

let is_and = lv[columns::IS_AND];
let is_or = lv[columns::IS_OR];
let is_xor = lv[columns::IS_XOR];
let is_and = lv.op.is_and;
let is_or = lv.op.is_or;
let is_xor = lv.op.is_xor;

// Flags must be boolean.
for &flag in &[is_and, is_or, is_xor] {
Expand All @@ -256,31 +277,27 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
let and_coeff = is_and - is_or - is_xor * FE::TWO;

// Ensure that all bits are indeed bits.
for input_bits_cols in [columns::INPUT0, columns::INPUT1] {
for i in input_bits_cols {
let bit = lv[i];
for input_bits in [lv.input0, lv.input1] {
for bit in input_bits {
yield_constr.constraint(bit * (bit - P::ONES));
}
}

// Form the result
for (result_col, x_bits_cols, y_bits_cols) in izip!(
columns::RESULT,
columns::limb_bit_cols_for_input(columns::INPUT0),
columns::limb_bit_cols_for_input(columns::INPUT1),
for (result_limb, x_bits, y_bits) in izip!(
lv.result,
lv.input0.chunks(PACKED_LIMB_BITS),
lv.input1.chunks(PACKED_LIMB_BITS),
) {
let x: P = limb_from_bits_le(x_bits_cols.clone().map(|col| lv[col]));
let y: P = limb_from_bits_le(y_bits_cols.clone().map(|col| lv[col]));

let x_bits = x_bits_cols.map(|i| lv[i]);
let y_bits = y_bits_cols.map(|i| lv[i]);
let x: P = limb_from_bits_le(x_bits.iter().copied());
let y: P = limb_from_bits_le(y_bits.iter().copied());

let x_land_y: P = izip!(0.., x_bits, y_bits)
.map(|(i, x_bit, y_bit)| x_bit * y_bit * FE::from_canonical_u64(1 << i))
.map(|(i, &x_bit, &y_bit)| x_bit * y_bit * FE::from_canonical_u64(1 << i))
.sum();
let x_op_y = sum_coeff * (x + y) + and_coeff * x_land_y;

yield_constr.constraint(lv[result_col] - x_op_y);
yield_constr.constraint(result_limb - x_op_y);
}
}

Expand All @@ -290,11 +307,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
vars: &Self::EvaluationFrameTarget,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let lv = vars.get_local_values();
let lv: &[ExtensionTarget<D>; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap();
let lv: &LogicColumnsView<ExtensionTarget<D>> = lv.borrow();

let is_and = lv[columns::IS_AND];
let is_or = lv[columns::IS_OR];
let is_xor = lv[columns::IS_XOR];
let is_and = lv.op.is_and;
let is_or = lv.op.is_or;
let is_xor = lv.op.is_xor;

// Flags must be boolean.
for &flag in &[is_and, is_or, is_xor] {
Expand All @@ -318,28 +336,25 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
};

// Ensure that all bits are indeed bits.
for input_bits_cols in [columns::INPUT0, columns::INPUT1] {
for i in input_bits_cols {
let bit = lv[i];
for input_bits in [lv.input0, lv.input1] {
for bit in input_bits {
let constr = builder.mul_sub_extension(bit, bit, bit);
yield_constr.constraint(builder, constr);
}
}

// Form the result
for (result_col, x_bits_cols, y_bits_cols) in izip!(
columns::RESULT,
columns::limb_bit_cols_for_input(columns::INPUT0),
columns::limb_bit_cols_for_input(columns::INPUT1),
for (result_limb, x_bits, y_bits) in izip!(
lv.result,
lv.input0.chunks(PACKED_LIMB_BITS),
lv.input1.chunks(PACKED_LIMB_BITS),
) {
let x = limb_from_bits_le_recursive(builder, x_bits_cols.clone().map(|i| lv[i]));
let y = limb_from_bits_le_recursive(builder, y_bits_cols.clone().map(|i| lv[i]));
let x_bits = x_bits_cols.map(|i| lv[i]);
let y_bits = y_bits_cols.map(|i| lv[i]);
let x = limb_from_bits_le_recursive(builder, x_bits.iter().copied());
let y = limb_from_bits_le_recursive(builder, y_bits.iter().copied());

let x_land_y = izip!(0usize.., x_bits, y_bits).fold(
builder.zero_extension(),
|acc, (i, x_bit, y_bit)| {
|acc, (i, &x_bit, &y_bit)| {
builder.arithmetic_extension(
F::from_canonical_u64(1 << i),
F::ONE,
Expand All @@ -354,7 +369,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
let x_op_y = builder.mul_add_extension(sum_coeff, y, x_op_y);
builder.mul_add_extension(and_coeff, x_land_y, x_op_y)
};
let constr = builder.sub_extension(lv[result_col], x_op_y);
let constr = builder.sub_extension(result_limb, x_op_y);
yield_constr.constraint(builder, constr);
}
}
Expand All @@ -371,9 +386,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F,
#[cfg(test)]
mod tests {
use anyhow::Result;
use itertools::Itertools;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use starky::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};

use super::*;
use crate::logic::LogicStark;

#[test]
Expand Down Expand Up @@ -401,4 +421,52 @@ mod tests {
};
test_stark_circuit_constraints::<F, C, S, D>(stark)
}

#[test]
fn test_generate_eval_consistency() {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type S = LogicStark<F, D>;

let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
const N_ITERS: usize = 1000;

for _ in 0..N_ITERS {
for op in [Op::And, Op::Or, Op::Xor] {
// Generate a trace row from an operation on random values.
let operation = Operation::new(op, U256(rng.gen()), U256(rng.gen()));
let expected = operation.result;
let row = operation.into_row::<F>();
let lv = EvmStarkFrame::from_values(&row, &[F::ZERO; NUM_COLUMNS], &[]);

let stark = S::default();
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
F::ONE,
F::ONE,
F::ONE,
);

// Evaluate constraints.
stark.eval_packed_generic(&lv, &mut constraint_consumer);
for acc in constraint_consumer.accumulators() {
assert_eq!(acc, F::ZERO);
}

// Split each expected U256 limb into two.
let expected_limbs = expected.as_ref().iter().flat_map(|&limb| {
[
F::from_canonical_u32(limb as u32),
F::from_canonical_u32((limb >> 32) as u32),
]
});

// Check that the result limbs match the expected limbs.
assert!(expected_limbs
.zip_eq(&row[LOGIC_COL_MAP.result[0]..])
.all(|(x, &y)| x == y));
}
}
}
}

0 comments on commit a591b6f

Please sign in to comment.