Skip to content

Commit

Permalink
trace gen sketch
Browse files Browse the repository at this point in the history
  • Loading branch information
Okm165 committed Nov 8, 2024
1 parent 5e99570 commit d76e435
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 34 deletions.
2 changes: 0 additions & 2 deletions crates/core/executor/src/events/precompiles/poseidon2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ pub struct Poseidon2PermEvent {
pub clk: u32,
/// The pointer to the x value.
pub input_ptr: u32,
/// The x value as a list of words.
pub input: Vec<u32>,
/// The memory records for the x value.
pub input_memory_records: Vec<MemoryWriteRecord>,
/// The local memory access records.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ where
let mut state: [AB::Expr; WIDTH] = local.state.map(|x| x.into());

// Perform permutation on the state
Self::external_linear_layer::<AB>(&mut state);
Self::external_linear_layer::<AB::Expr>(&mut state);

for round in 0..NUM_FULL_ROUNDS / 2 {
Self::eval_full_round(
Expand Down Expand Up @@ -114,29 +114,29 @@ where
}

impl Poseidon2PermChip {
pub fn external_linear_layer<AB: SP1AirBuilder>(state: &mut [AB::Expr; WIDTH]) {
pub fn external_linear_layer<F: AbstractField>(state: &mut [F; WIDTH]) {
for j in (0..WIDTH).step_by(4) {
Self::apply_m_4::<AB>(&mut state[j..j + 4]);
Self::apply_m_4::<F>(&mut state[j..j + 4]);
}
let sums: [AB::Expr; 4] = core::array::from_fn(|k| {
(0..WIDTH).step_by(4).map(|j| state[j + k].clone()).sum::<AB::Expr>()
let sums: [F; 4] = core::array::from_fn(|k| {
(0..WIDTH).step_by(4).map(|j| state[j + k].clone()).sum::<F>()
});

for j in 0..WIDTH {
state[j] = state[j].clone() + sums[j % 4].clone();
}
}

pub fn internal_linear_layer<AB: SP1AirBuilder>(state: &mut [AB::Expr; WIDTH]) {
let matmul_constants: [AB::F; WIDTH] = POSEIDON2_INTERNAL_MATRIX_DIAG_16_BABYBEAR_MONTY
pub fn internal_linear_layer<F: AbstractField + Clone>(state: &mut [F; WIDTH]) {
let matmul_constants: [F; WIDTH] = POSEIDON2_INTERNAL_MATRIX_DIAG_16_BABYBEAR_MONTY
.iter()
.map(|x| AB::F::from_wrapped_u32(x.as_canonical_u32()))
.map(|x| F::from_wrapped_u32(x.as_canonical_u32()))
.collect::<Vec<_>>()
.try_into()
.unwrap();
Self::matmul_internal::<AB>(state, matmul_constants);
let monty_inverse = AB::F::from_wrapped_u32(MONTY_INVERSE.as_canonical_u32());
state.iter_mut().for_each(|i| *i = i.clone() * monty_inverse);
Self::matmul_internal(state, matmul_constants);
let monty_inverse = F::from_wrapped_u32(MONTY_INVERSE.as_canonical_u32());
state.iter_mut().for_each(|i| *i = i.clone() * monty_inverse.clone());
}

pub fn eval_full_round<AB>(
Expand All @@ -151,14 +151,14 @@ impl Poseidon2PermChip {
*s = s.clone() + *r;
Self::eval_sbox(&full_round.sbox[i], s, builder);
}
Self::external_linear_layer::<AB>(state);
Self::external_linear_layer::<AB::Expr>(state);
for (state_i, post_i) in state.iter_mut().zip(full_round.post) {
builder.assert_eq(state_i.clone(), post_i);
*state_i = post_i.into();
}
}

fn eval_partial_round<AB>(
pub fn eval_partial_round<AB>(
state: &mut [AB::Expr; WIDTH],
partial_round: &PartialRound<AB::Var>,
round_constant: &AB::F,
Expand All @@ -172,13 +172,10 @@ impl Poseidon2PermChip {
builder.assert_eq(state[0].clone(), partial_round.post_sbox);
state[0] = partial_round.post_sbox.into();

Self::internal_linear_layer::<AB>(state);
Self::internal_linear_layer::<AB::Expr>(state);
}

pub fn apply_m_4<AB>(x: &mut [AB::Expr])
where
AB: SP1AirBuilder,
{
pub fn apply_m_4<F: AbstractField>(x: &mut [F]) {
let t01 = x[0].clone() + x[1].clone();
let t23 = x[2].clone() + x[3].clone();
let t0123 = t01.clone() + t23.clone();
Expand All @@ -191,13 +188,13 @@ impl Poseidon2PermChip {
x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3]
}

pub fn matmul_internal<AB: SP1AirBuilder>(
state: &mut [AB::Expr; WIDTH],
mat_internal_diag_m_1: [AB::F; WIDTH],
pub fn matmul_internal<F: AbstractField + Clone>(
state: &mut [F; WIDTH],
mat_internal_diag_m_1: [F; WIDTH],
) {
let sum: AB::Expr = state.iter().cloned().sum();
let sum: F = state.iter().cloned().sum();
for i in 0..WIDTH {
state[i] = state[i].clone() * mat_internal_diag_m_1[i];
state[i] = state[i].clone() * mat_internal_diag_m_1[i].clone();
state[i] = state[i].clone() + sum.clone();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use super::{columns::NUM_POSEIDON2PERM_COLS, Poseidon2PermChip};
use crate::syscall::precompiles::poseidon2::{permutation::columns::Poseidon2PermCols, WIDTH};
use super::{
columns::{FullRound, PartialRound, NUM_POSEIDON2PERM_COLS},
Poseidon2PermChip,
};
use crate::syscall::precompiles::poseidon2::{
permutation::columns::Poseidon2PermCols, NUM_FULL_ROUNDS, NUM_PARTIAL_ROUNDS, WIDTH,
};
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
Expand All @@ -8,7 +13,7 @@ use sp1_core_executor::{
syscalls::SyscallCode,
ExecutionRecord, Program,
};
use sp1_primitives::consts::WORD_SIZE;
use sp1_primitives::RC_16_30_U32;
use sp1_stark::air::MachineAir;
use sp1_stark::MachineRecord;
use std::borrow::BorrowMut;
Expand Down Expand Up @@ -101,21 +106,78 @@ impl Poseidon2PermChip {
) {
let cols: &mut Poseidon2PermCols<F> = row.as_mut_slice().borrow_mut();

// Decode input
let input: Vec<F> = event.input.iter().map(|e| F::from_canonical_u32(*e)).collect();

// Assign basic values to the columns.
cols.is_real = F::one();
cols.shard = F::from_canonical_u32(event.shard);
cols.clk = F::from_canonical_u32(event.clk);
cols.input_ptr = F::from_canonical_u32(event.input_ptr);

// Populate memory columns. Q!
for i in 0..(WIDTH / WORD_SIZE) {
// Populate memory columns.
for i in 0..WIDTH {
cols.input_memory[i]
.populate(MemoryRecordEnum::Write(event.input_memory_records[i]), blu);
cols.input_range_checker[i].populate(event.input_memory_records[i].prev_value);
cols.state[i] = F::from_canonical_u32(event.input_memory_records[i].prev_value);
}

// Perform permutation on the state
Self::external_linear_layer(&mut cols.state);

for round in 0..NUM_FULL_ROUNDS / 2 {
Self::populate_full_round(
&mut cols.state,
&cols.beginning_full_rounds[round],
&RC_16_30_U32[round].map(F::from_canonical_u32),
);
}

todo!();
for round in 0..NUM_PARTIAL_ROUNDS {
Self::populate_partial_round(
&mut cols.state,
&cols.partial_rounds[round],
&RC_16_30_U32[round].map(F::from_canonical_u32)[0],
);
}

for round in 0..NUM_FULL_ROUNDS / 2 {
Self::populate_full_round(
&mut cols.state,
&cols.ending_full_rounds[round],
&RC_16_30_U32[round].map(F::from_canonical_u32),
);
}
}

pub fn populate_full_round<F: PrimeField32>(
state: &mut [F; WIDTH],
full_round: &FullRound<F>,
round_constants: &[F; WIDTH],
) {
for (s, r) in state.iter_mut().zip(round_constants.iter()) {
*s = *s + *r;
Self::populate_sbox(s);
}
Self::external_linear_layer(state);
for (state_i, post_i) in state.iter_mut().zip(full_round.post) {
*state_i = post_i;
}
}

pub fn populate_partial_round<F: PrimeField32>(
state: &mut [F; WIDTH],
partial_round: &PartialRound<F>,
round_constant: &F,
) {
state[0] = state[0] + *round_constant;
Self::populate_sbox(&mut state[0]);

state[0] = partial_round.post_sbox;

Self::internal_linear_layer(state);
}

#[inline]
pub fn populate_sbox<F: PrimeField32>(x: &mut F) {
*x = x.exp_const_u64::<7>();
}
}

0 comments on commit d76e435

Please sign in to comment.