Skip to content

Commit

Permalink
added consts
Browse files Browse the repository at this point in the history
  • Loading branch information
kevjue committed Dec 6, 2023
1 parent 2b1f178 commit e00c76b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 44 deletions.
62 changes: 31 additions & 31 deletions curta/src/machine/hash/blake/blake2b/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use log::debug;
use plonky2::util::log2_ceil;

use super::data::{BLAKE2BConstNums, BLAKE2BConsts, BLAKE2BData};
use super::{BLAKE2BAir, IV};
use super::{BLAKE2BAir, COMPRESS_LENGTH, IV, STATE_SIZE};
use crate::chip::memory::instruction::MemorySliceIndex;
use crate::chip::memory::pointer::slice::Slice;
use crate::chip::memory::time::Time;
Expand All @@ -20,7 +20,8 @@ use crate::machine::hash::blake::blake2b::data::{
BLAKE2BMemory, BLAKE2BPublicData, BLAKE2BTraceData, MemoryArray,
};
use crate::machine::hash::blake::blake2b::{
COMPRESS_IV, SIGMA_PERMUTATIONS, V_INDICES, V_LAST_WRITE_AGES,
COMPRESS_IV, MIX_LENGTH, MSG_ARRAY_SIZE, NUM_MIX_ROUNDS, SIGMA_PERMUTATIONS, V_INDICES,
V_LAST_WRITE_AGES,
};
use crate::math::prelude::*;

Expand Down Expand Up @@ -205,7 +206,7 @@ where

let num_total_mix_iterations_element = builder
.constant::<ElementRegister>(&L::Field::from_canonical_usize(num_total_mix_iterations));
let mut v_indices = MemoryArray::<L, 8, 4>::new(builder);
let mut v_indices = MemoryArray::<L, MIX_LENGTH, 4>::new(builder);
for (i, indices) in V_INDICES.iter().enumerate() {
v_indices.store_row(
builder,
Expand All @@ -220,7 +221,7 @@ where
"num_total_mix_iterations_element",
);

let mut v_last_write_ages = MemoryArray::<L, 8, 4>::new(builder);
let mut v_last_write_ages = MemoryArray::<L, MIX_LENGTH, 4>::new(builder);
for (i, ages) in V_LAST_WRITE_AGES.iter().enumerate() {
v_last_write_ages.store_row(
builder,
Expand All @@ -231,15 +232,14 @@ where
);
}

let mut permutations = MemoryArray::<L, 12, 16>::new(builder);
let mut permutations = MemoryArray::<L, NUM_MIX_ROUNDS, MSG_ARRAY_SIZE>::new(builder);
let num_compresses_element = builder.constant::<ElementRegister>(
&L::Field::from_canonical_usize(num_real_compresses + num_dummy_compresses),
);
let num_full_compresses_element = builder.constant::<ElementRegister>(
&L::Field::from_canonical_usize(num_real_compresses + num_dummy_compresses - 1),
);
builder.watch(&num_compresses_element, "num_compresses_element");

builder.watch(&num_full_compresses_element, "num_full_compresses_element");

for (i, permutation) in SIGMA_PERMUTATIONS.iter().enumerate() {
Expand Down Expand Up @@ -376,15 +376,15 @@ where
let compress_index =
builder.expression(clk.expr() - compress_id.expr() * const_nums.const_96.expr());

let compress_iteration = builder.alloc::<ElementRegister>();
builder.set_to_expression_first_row(&compress_iteration, L::Field::ZERO.into());
let mix_id = builder.alloc::<ElementRegister>();
builder.set_to_expression_first_row(&mix_id, L::Field::ZERO.into());
builder.set_to_expression_transition(
&compress_iteration.next(),
cycle_8_end_bit.not_expr() * compress_iteration.expr()
&mix_id.next(),
cycle_8_end_bit.not_expr() * mix_id.expr()
+ cycle_8_end_bit.expr()
* (cycle_96_end_bit.expr() * const_nums.const_0.expr()
+ (cycle_96_end_bit.not_expr()
* (compress_iteration.expr() + const_nums.const_1.expr()))),
* (mix_id.expr() + const_nums.const_1.expr()))),
);

let at_end_compress = builder.load(
Expand Down Expand Up @@ -446,7 +446,7 @@ where
let num_full_compresses_element = builder.constant::<ElementRegister>(
&L::Field::from_canonical_usize(num_real_compresses + num_dummy_compresses - 1),
);
for i in 0..96 {
for i in 0..COMPRESS_LENGTH {
builder.store(
&save_final_v.get(i),
if i < 92 { false_const } else { true_const },
Expand Down Expand Up @@ -534,7 +534,7 @@ where
compress_id,
previous_compress_id,
compress_index,
compress_iteration,
mix_id,
mix_index,
}
}
Expand Down Expand Up @@ -624,7 +624,7 @@ where

// Each message chunk will be read 24 times per compress. Two times per compress iteration.
for (compress_id_value, padded_chunk) in padded_chunks.iter().enumerate() {
assert!(padded_chunk.len() == 16);
assert!(padded_chunk.len() == MSG_ARRAY_SIZE);
for (j, word) in padded_chunk.iter().enumerate() {
builder.store(
&m.get(compress_id_value * 16 + j),
Expand Down Expand Up @@ -700,22 +700,22 @@ where
assert!(degree_log < 31, "AIR degree is too large");
debug!("AIR degree after padding: {}", 1 << degree_log);

let num_dummy_compresses = (1 << degree_log) / 96 + 1 - num_real_compresses;
let length_last_compress = (1 << degree_log) % 96;
let num_dummy_compresses = (1 << degree_log) / COMPRESS_LENGTH + 1 - num_real_compresses;
let length_last_compress = (1 << degree_log) % COMPRESS_LENGTH;
let length_last_compress_element = builder
.constant::<ElementRegister>(&L::Field::from_canonical_usize(length_last_compress));
let num_dummy_rows = (num_dummy_compresses - 1) * 96 + length_last_compress;
let num_dummy_rows = (num_dummy_compresses - 1) * COMPRESS_LENGTH + length_last_compress;

let num_rows_element =
builder.constant::<ElementRegister>(&L::Field::from_canonical_usize(1 << degree_log));

// create the const numbers data
let const_nums = Self::blake2b_const_nums(builder);

let mut num_total_mix_iterations = (num_real_compresses + num_dummy_compresses - 1) * 12;
let num_mix_iterations_last_compress = length_last_compress / 8;
assert!(num_mix_iterations_last_compress == 4 || num_mix_iterations_last_compress == 8);
num_total_mix_iterations += num_mix_iterations_last_compress;
let mut num_total_mixes = (num_real_compresses + num_dummy_compresses - 1) * NUM_MIX_ROUNDS;
let num_mixes_last_compress = length_last_compress / MIX_LENGTH;
assert!(num_mixes_last_compress == 4 || num_mixes_last_compress == 8);
num_total_mixes += num_mixes_last_compress;

let public = BLAKE2BPublicData {
padded_chunks: padded_chunks.to_vec(),
Expand All @@ -732,8 +732,8 @@ where
num_real_compresses,
&num_real_compresses_element,
num_dummy_compresses,
num_total_mix_iterations,
num_mix_iterations_last_compress,
num_total_mixes,
num_mixes_last_compress,
&const_nums,
);

Expand Down Expand Up @@ -1087,7 +1087,7 @@ where

let mut m_idx_1 = data.consts.permutations.get_at(
builder,
data.trace.compress_iteration,
data.trace.mix_id,
permutation_col,
Some("permutation".to_string()),
);
Expand All @@ -1099,7 +1099,7 @@ where

let mut m_idx_2 = data.consts.permutations.get_at(
builder,
data.trace.compress_iteration,
data.trace.mix_id,
permutation_col,
Some("permutation".to_string()),
);
Expand Down Expand Up @@ -1211,7 +1211,7 @@ where
// If we are at the last row of compress, then compute and save the h value.

// First load the previous round's h value.
let h_workspace_1 = builder.alloc_array::<U64Register>(8);
let h_workspace_1 = builder.alloc_array::<U64Register>(STATE_SIZE);

// Read dummy h values if any of the following conditions are true
// 1) NOT at last row of a compress
Expand All @@ -1233,7 +1233,7 @@ where
&data.consts.dummy_ts,
&data.const_nums.const_0,
);
for i in 0..8 {
for i in 0..STATE_SIZE {
let i_element = builder.constant::<ElementRegister>(&L::Field::from_canonical_usize(i));
let mut h_idx = builder.expression(
data.trace.previous_compress_id.expr() * data.const_nums.const_8.expr()
Expand All @@ -1257,7 +1257,7 @@ where
}

// Xor the first 8 final v values
let h_workspace_2 = builder.alloc_array::<U64Register>(8);
let h_workspace_2 = builder.alloc_array::<U64Register>(STATE_SIZE);

// Read dummy v_final values if NOT at last row of a compress OR in a dummy compress.
//
Expand All @@ -1274,7 +1274,7 @@ where
&data.consts.dummy_ts,
&data.trace.compress_id,
);
for i in 0..8 {
for i in 0..STATE_SIZE {
let i_element = builder.constant::<ElementRegister>(&L::Field::from_canonical_usize(i));
let v_final_idx =
builder.select(read_dummy_v_final_idx, &data.consts.dummy_index, &i_element);
Expand All @@ -1289,15 +1289,15 @@ where
}

// Xor the second 8 final v values
let h = builder.alloc_array::<U64Register>(8);
let h = builder.alloc_array::<U64Register>(STATE_SIZE);

// Save h into memory if we are at the final row and it is not the end compress and not in a dummy compress.
let save_h = builder.expression(
data.trace.is_compress_final_row.expr()
* data.trace.at_end_compress.not_expr()
* data.trace.at_dummy_compress.not_expr(),
);
for i in 0..8 {
for i in 0..STATE_SIZE {
let i_element = builder.constant::<ElementRegister>(&L::Field::from_canonical_usize(i));
let i_element_plus_8 = builder.add(i_element, data.const_nums.const_8);

Expand Down
2 changes: 1 addition & 1 deletion curta/src/machine/hash/blake/blake2b/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub struct BLAKE2BTraceData {
pub(crate) compress_id: ElementRegister,
pub(crate) previous_compress_id: ElementRegister,
pub(crate) compress_index: ElementRegister,
pub(crate) compress_iteration: ElementRegister,
pub(crate) mix_id: ElementRegister,
pub(crate) mix_index: ElementRegister,
}

Expand Down
14 changes: 8 additions & 6 deletions curta/src/machine/hash/blake/blake2b/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ pub struct BLAKE2BAir<L> {
}

const NUM_MIX_ROUNDS: usize = 12;
const MIX_LENGTH: usize = 8;
const MSG_ARRAY_SIZE: usize = 16;
const HASH_ARRAY_SIZE: usize = 8;
const STATE_SIZE: usize = 8;
const WORK_VECTOR_SIZE: usize = 16;
const COMPRESS_LENGTH: usize = MIX_LENGTH * NUM_MIX_ROUNDS;

const IV: [u64; HASH_ARRAY_SIZE] = [
const IV: [u64; STATE_SIZE] = [
0x6a09e667f2bdc928,
0xbb67ae8584caa73b,
0x3c6ef372fe94f82b,
Expand All @@ -28,7 +30,7 @@ const IV: [u64; HASH_ARRAY_SIZE] = [
// we assume that the output is 32 bytes
// So that means the initial hash entry to be
// 0x6a09e667f3bcc908 xor 0x01010020
const COMPRESS_IV: [u64; HASH_ARRAY_SIZE] = [
const COMPRESS_IV: [u64; STATE_SIZE] = [
0x6a09e667f3bcc908,
0xbb67ae8584caa73b,
0x3c6ef372fe94f82b,
Expand All @@ -39,7 +41,7 @@ const COMPRESS_IV: [u64; HASH_ARRAY_SIZE] = [
0x5be0cd19137e2179,
];

const V_INDICES: [[u8; 4]; 8] = [
const V_INDICES: [[u8; 4]; MIX_LENGTH] = [
[0, 4, 8, 12],
[1, 5, 9, 13],
[2, 6, 10, 14],
Expand All @@ -50,7 +52,7 @@ const V_INDICES: [[u8; 4]; 8] = [
[3, 4, 9, 14],
];

const V_LAST_WRITE_AGES: [[u8; 4]; 8] = [
const V_LAST_WRITE_AGES: [[u8; 4]; MIX_LENGTH] = [
[4, 1, 2, 3],
[4, 5, 2, 3],
[4, 5, 6, 3],
Expand All @@ -61,7 +63,7 @@ const V_LAST_WRITE_AGES: [[u8; 4]; 8] = [
[4, 7, 6, 5],
];

const SIGMA_PERMUTATIONS: [[u8; 16]; 12] = [
const SIGMA_PERMUTATIONS: [[u8; MSG_ARRAY_SIZE]; NUM_MIX_ROUNDS] = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
[11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
Expand Down
12 changes: 6 additions & 6 deletions curta/src/machine/hash/blake/blake2b/pure.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
use super::{COMPRESS_IV, HASH_ARRAY_SIZE, WORK_VECTOR_SIZE};
use super::{COMPRESS_IV, STATE_SIZE, WORK_VECTOR_SIZE};
use crate::machine::hash::blake::blake2b::SIGMA_PERMUTATIONS;

pub struct BLAKE2BPure;

impl BLAKE2BPure {
pub fn compress(
msg_chunk: &[u8],
state: &mut [u64; HASH_ARRAY_SIZE],
state: &mut [u64; STATE_SIZE],
bytes_compressed: u64,
last_chunk: bool,
) -> [u64; HASH_ARRAY_SIZE] {
) -> [u64; STATE_SIZE] {
// Set up the work vector V
let mut v: [u64; WORK_VECTOR_SIZE] = [0; WORK_VECTOR_SIZE];

v[..8].copy_from_slice(&state[..HASH_ARRAY_SIZE]);
v[..8].copy_from_slice(&state[..STATE_SIZE]);
v[8..16].copy_from_slice(&COMPRESS_IV);

v[12] ^= bytes_compressed;
Expand Down Expand Up @@ -102,11 +102,11 @@ impl BLAKE2BPure {
);
}

for i in 0..HASH_ARRAY_SIZE {
for i in 0..STATE_SIZE {
state[i] ^= v[i];
}

for i in 0..HASH_ARRAY_SIZE {
for i in 0..STATE_SIZE {
state[i] ^= v[i + 8];
}

Expand Down

0 comments on commit e00c76b

Please sign in to comment.