Skip to content

Commit

Permalink
most multiplicities fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
kevjue committed Nov 29, 2023
1 parent 904ffba commit 9efd0fb
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 47 deletions.
109 changes: 80 additions & 29 deletions curta/src/machine/hash/blake/blake2b/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ where
dummy_index,
dummy_ts,
first_compress_h_read_ts,
length_last_round: *length_last_round,
}
}

Expand All @@ -205,7 +206,7 @@ where
end_bits: &ArrayRegister<BitRegister>,
digest_bits: &ArrayRegister<BitRegister>,
num_dummy_rounds: usize,
length_last_round: usize,
length_last_round_element: &ElementRegister,
) -> BLAKE2BTraceData {
let (cycle_3_end_bit, cycle_4_end_bit, cycle_8_end_bit, cycle_96_end_bit) =
Self::cycles_end_bits(builder);
Expand Down Expand Up @@ -233,6 +234,28 @@ where
);
}

let at_partial_compress = builder.uninit_slice();
for i in 1..end_bits.len() + num_dummy_rounds - 1 {
builder.store(
&at_partial_compress.get(i),
false_const,
&Time::zero(),
Some(const_nums.const_96),
);
}
builder.store(
&at_partial_compress.get(end_bits.len() + num_dummy_rounds - 1),
true_const,
&Time::zero(),
Some(const_nums.const_96),
);
builder.store(
&at_partial_compress.get(end_bits.len() + num_dummy_rounds),
false_const,
&Time::zero(),
Some(*length_last_round_element),
);

let digest_bit = builder.uninit_slice();
for (i, digest_bit_val) in digest_bits.iter().enumerate() {
builder.store(
Expand All @@ -250,13 +273,11 @@ where
Some(const_nums.const_96),
);
}
let length_last_round_element =
builder.constant(&L::Field::from_canonical_usize(length_last_round));
builder.store(
&digest_bit.get(num_dummy_rounds - 1 + digest_bits.len()),
false_const,
&Time::zero(),
Some(length_last_round_element),
Some(*length_last_round_element),
);

// `compress_id` is a register is computed by counting the number of cycles. We do this by
Expand Down Expand Up @@ -357,6 +378,7 @@ where
is_compress_third_row,
at_first_compress,
at_last_hash_compress,
at_partial_compress,
cycle_96_end_bit,
digest_bit,
save_final_v,
Expand Down Expand Up @@ -394,17 +416,12 @@ where
);
}

let num_h_dummy_reads = builder.alloc_public::<ElementRegister>();
builder.set_to_expression(
&num_h_dummy_reads,
num_messages.expr() * const_nums.const_184.expr(),
);
// All of the non compress initial rows will read this dummy index
builder.store(
&h.get_at(consts.dummy_index),
const_nums.const_0_u64,
&Time::from_element(consts.dummy_ts),
Some(num_h_dummy_reads),
&Time::from_element(consts.first_compress_h_read_ts),
Some(const_nums.const_184),
);

// Initialize the v memory
Expand Down Expand Up @@ -560,7 +577,7 @@ where
end_bits,
digest_bits,
num_dummy_rounds,
length_last_round,
&length_last_round_element,
);

// create the memory data
Expand Down Expand Up @@ -834,15 +851,15 @@ where
builder.select(save_final_v, &data.trace.compress_id, &data.consts.dummy_ts);

let updated_v_values = [updated_v0, updated_v1, updated_v2, updated_v3];
for (i, value) in updated_v_values.iter().enumerate() {
for (value, v_index) in updated_v_values.iter().zip(v_indices.iter()) {
builder.store(
&data.memory.v.get_at(v_indices[i]),
&data.memory.v.get_at(*v_index),
*value,
&Time::from_element(builder.clk),
None,
);

let v_final_idx = builder.select(save_final_v, &v_indices[i], &data.consts.dummy_index);
let v_final_idx = builder.select(save_final_v, v_index, &data.consts.dummy_index);
let v_final_value = builder.select(save_final_v, value, &data.const_nums.const_0_u64);

builder.store(
Expand All @@ -856,8 +873,6 @@ where
// If we are at the last cycle of the round, then compute and save the h value.

// First load the previous round's h value.
builder.watch(&previous_compress_id, "previous compress id");

let h_workspace_1 = builder.alloc_array::<U64Register>(8);
for i in 0..8 {
let mut h_value = builder.load(
Expand All @@ -869,7 +884,6 @@ where
&data.consts.iv_values.get(i),
&h_value,
);
builder.watch(&h_value, "previous h");
builder.set_to_expression(&h_workspace_1.get(i), h_value.expr());
}
// Xor the first 8 final v values
Expand All @@ -882,7 +896,6 @@ where
for i in 0..8 {
let v_i = builder.load(&data.memory.v_final.get(i), &Time::from_element(read_ts));
let updated_h = builder.xor(h_workspace_1.get(i), v_i);
builder.watch(&updated_h, "intermediate h");
builder.set_to_expression(&h_workspace_2.get(i), updated_h.expr());
}

Expand All @@ -893,14 +906,27 @@ where
);
let save_digest =
builder.expression(data.trace.cycle_96_end_bit.expr() * digest_bit.expr());
builder.watch(&save_digest, "save digest");

let h = builder.alloc_array::<U64Register>(8);
let num_h_reads = builder.mul(
data.trace.cycle_96_end_bit.as_element(),
data.const_nums.const_97,
let next_compress_id = builder.add(data.trace.compress_id, data.const_nums.const_1);
let next_compress_is_partial = builder.load(
&data.trace.at_partial_compress.get_at(next_compress_id),
&Time::zero(),
);
let length_last_round_plus_1 = builder.alloc_public::<ElementRegister>();
builder.set_to_expression(
&length_last_round_plus_1,
data.consts.length_last_round.expr() + data.const_nums.const_1.expr(),
);
let num_h_reads = builder.alloc::<ElementRegister>();
builder.set_to_expression(
&num_h_reads,
data.trace.cycle_96_end_bit.expr()
* (next_compress_is_partial.expr() * length_last_round_plus_1.expr()
+ next_compress_is_partial.not_expr() * data.const_nums.const_97.expr()),
);
for i in 0..8 {
let i_element = builder.constant::<ElementRegister>(&L::Field::from_canonical_usize(i));
let v_value = builder.load(
&data.memory.v_final.get(i + 8),
&Time::from_element(read_ts),
Expand All @@ -910,22 +936,37 @@ where
builder.watch(&xor, "final h");

// If we are at the last compress of a message, then save the IV into memory.h
let save_h_value =
builder.select(data.trace.cycle_96_end_bit, &xor, &h_workspace_1.get(i));
let save_h_value = builder.select(
data.trace.cycle_96_end_bit,
&xor,
&data.const_nums.const_0_u64,
);

let h_idx = builder.select(
data.trace.cycle_96_end_bit,
&i_element,
&data.consts.dummy_index,
);

builder.watch(&save_h_value, "save h value");
builder.store(
&data.memory.h.get(i),
&data.memory.h.get_at(h_idx),
save_h_value,
&Time::from_element(data.trace.compress_id),
Some(num_h_reads),
);
}

let num_dummy_h_reads = builder.mul(
data.trace.cycle_96_end_bit.as_element(),
data.const_nums.const_184,
let num_dummy_h_reads = builder.alloc::<ElementRegister>();
builder.set_to_expression(
&num_dummy_h_reads,
data.trace.cycle_96_end_bit.expr()
* (next_compress_is_partial.expr()
* ((data.consts.length_last_round.expr() - data.const_nums.const_4.expr())
* data.const_nums.const_2.expr())
+ next_compress_is_partial.not_expr() * data.const_nums.const_184.expr()),
);

builder.store(
&data.memory.h.get_at(data.consts.dummy_index),
data.const_nums.const_0_u64,
Expand Down Expand Up @@ -978,6 +1019,16 @@ where
&data.consts.permutations.flattened_memory.get(127),
"permutations[127]",
);
builder.watch_memory(&data.memory.h.get(0), "h[0]");
builder.watch_memory(&data.memory.h.get(7), "h[7]");
builder.watch_memory(
&data.memory.h.get_at(data.consts.dummy_index),
"h[dummy_index]",
);
builder.watch_memory(&data.memory.v.get(0), "v[0]");
builder.watch_memory(&data.memory.v.get(15), "v[15]");
builder.watch_memory(&data.memory.v_final.get(0), "v_final[0]");
builder.watch_memory(&data.memory.v_final.get(15), "v_final[15]");

hash_state_public
}
Expand Down
2 changes: 1 addition & 1 deletion curta/src/machine/hash/blake/blake2b/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub mod test_utils {
type CubicParams = GoldilocksCubicParameters;
type Instruction = UintInstruction;

const NUM_FREE_COLUMNS: usize = 1387;
const NUM_FREE_COLUMNS: usize = 1397;
const EXTENDED_COLUMNS: usize = 834;
}

Expand Down
2 changes: 2 additions & 0 deletions curta/src/machine/hash/blake/blake2b/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub struct BLAKE2BTraceData {
pub(crate) is_compress_third_row: BitRegister,
pub(crate) at_first_compress: BitRegister,
pub(crate) at_last_hash_compress: BitRegister,
pub(crate) at_partial_compress: Slice<BitRegister>,
pub(crate) cycle_96_end_bit: BitRegister,
pub(crate) digest_bit: Slice<BitRegister>,
pub(crate) save_final_v: Slice<BitRegister>,
Expand Down Expand Up @@ -59,6 +60,7 @@ pub struct BLAKE2BConsts<L: AirParameters> {
pub(crate) dummy_index: ElementRegister,
pub(crate) dummy_ts: ElementRegister,
pub(crate) first_compress_h_read_ts: ElementRegister,
pub(crate) length_last_round: ElementRegister,
}

pub struct BLAKE2BConstNums {
Expand Down
19 changes: 2 additions & 17 deletions curta/src/machine/hash/blake/blake2b/pure.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
use super::{COMPRESS_IV, HASH_ARRAY_SIZE, WORK_VECTOR_SIZE};
use crate::machine::hash::blake::blake2b::{NUM_MIX_ROUNDS, SIGMA_PERMUTATIONS};
use crate::machine::hash::blake::blake2b::SIGMA_PERMUTATIONS;

pub struct BLAKE2BPure;

impl BLAKE2BPure {
fn permute_msgs<T: Clone>(&self, arr: &[T], mix_round_num: usize) -> Vec<T> {
assert!(mix_round_num <= NUM_MIX_ROUNDS);

let permutation = SIGMA_PERMUTATIONS[mix_round_num % 10];
let mut result = vec![arr[0].clone(); arr.len()];

for (to_index, &from_index) in permutation.iter().enumerate() {
result[to_index] = arr[from_index as usize].clone();
}

result
}

pub fn compress(
msg_chunk: &[u8],
state: &mut [u64; HASH_ARRAY_SIZE],
Expand All @@ -39,9 +26,7 @@ impl BLAKE2BPure {
.map(|x| u64::from_le_bytes(x.try_into().unwrap()))
.collect::<Vec<_>>();

for i in 0..NUM_MIX_ROUNDS {
let s = SIGMA_PERMUTATIONS[i];

for s in SIGMA_PERMUTATIONS.iter() {
Self::mix(
&mut v,
0,
Expand Down

0 comments on commit 9efd0fb

Please sign in to comment.