Skip to content

Commit

Permalink
perf: Inline some hot spots in witness generation (#715)
Browse files Browse the repository at this point in the history
* Inline

* Some more
  • Loading branch information
Nashtare authored Oct 10, 2024
1 parent 34307b4 commit 08976ab
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
2 changes: 2 additions & 0 deletions evm_arithmetization/src/memory/segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,12 @@ impl Segment {
pub(crate) const COUNT: usize = 39;

/// Unscales this segment by `SEGMENT_SCALING_FACTOR`.
#[inline(always)]
pub(crate) const fn unscale(&self) -> usize {
*self as usize >> SEGMENT_SCALING_FACTOR
}

#[inline(always)]
pub(crate) const fn all() -> [Self; Self::COUNT] {
[
Self::Code,
Expand Down
20 changes: 15 additions & 5 deletions evm_arithmetization/src/witness/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use crate::witness::errors::ProgramError;
use crate::witness::errors::ProgramError::MemoryError;

impl MemoryChannel {
pub(crate) fn index(&self) -> usize {
#[inline(always)]
pub(crate) const fn index(&self) -> usize {
match *self {
Code => 0,
GeneralPurpose(n) => {
Expand All @@ -43,6 +44,7 @@ pub struct MemoryAddress {
}

impl MemoryAddress {
#[inline(always)]
pub(crate) const fn new(context: usize, segment: Segment, virt: usize) -> Self {
Self {
context,
Expand All @@ -69,7 +71,8 @@ impl MemoryAddress {
Ok(Self::new(context, Segment::all()[segment], virt))
}

pub(crate) fn increment(&mut self) {
#[inline(always)]
pub(crate) const fn increment(&mut self) {
self.virt = self.virt.saturating_add(1);
}
}
Expand Down Expand Up @@ -104,7 +107,8 @@ pub(crate) static DUMMY_MEMOP: MemoryOp = MemoryOp {
};

impl MemoryOp {
pub(crate) fn new(
#[inline(always)]
pub(crate) const fn new(
channel: MemoryChannel,
clock: usize,
address: MemoryAddress,
Expand All @@ -123,6 +127,7 @@ impl MemoryOp {
}
}

#[inline(always)]
pub(crate) const fn new_dummy_read(
address: MemoryAddress,
timestamp: usize,
Expand All @@ -137,6 +142,7 @@ impl MemoryOp {
}
}

#[inline(always)]
pub(crate) const fn sorting_key(&self) -> (usize, usize, usize, usize) {
(
self.address.context,
Expand Down Expand Up @@ -175,6 +181,7 @@ impl MemoryState {
}
}

#[inline]
pub(crate) fn get(&self, address: MemoryAddress) -> Option<U256> {
if address.context >= self.contexts.len() {
return None;
Expand All @@ -188,7 +195,7 @@ impl MemoryState {
return None;
}
let val = self.contexts[address.context].segments[address.segment].get(address.virt);
assert!(
debug_assert!(
val.bits() <= segment.bit_range(),
"Value {} exceeds {:?} range of {} bits",
val,
Expand Down Expand Up @@ -245,14 +252,15 @@ impl MemoryState {
}
}

#[inline]
pub(crate) fn set(&mut self, address: MemoryAddress, val: U256) {
while address.context >= self.contexts.len() {
self.contexts.push(MemoryContextState::default());
}

let segment = Segment::all()[address.segment];

assert!(
debug_assert!(
val.bits() <= segment.bit_range(),
"Value {} exceeds {:?} range of {} bits",
val,
Expand Down Expand Up @@ -320,6 +328,7 @@ pub(crate) struct MemorySegmentState {
}

impl MemorySegmentState {
#[inline]
pub(crate) fn get(&self, virtual_addr: usize) -> U256 {
self.content
.get(virtual_addr)
Expand All @@ -328,6 +337,7 @@ impl MemorySegmentState {
.unwrap_or_default()
}

#[inline]
pub(crate) fn set(&mut self, virtual_addr: usize, value: U256) {
if virtual_addr >= self.content.len() {
self.content.resize(virtual_addr + 1, None);
Expand Down
2 changes: 2 additions & 0 deletions evm_arithmetization/src/witness/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ impl<T: Copy + Field> Traces<T> {
self.poseidon_ops.truncate(checkpoint.poseidon_len);
}

#[inline(always)]
pub(crate) fn mem_ops_since(&self, checkpoint: TraceCheckpoint) -> &[MemoryOp] {
&self.memory_ops[checkpoint.memory_len..]
}

#[inline(always)]
pub(crate) fn clock(&self) -> usize {
self.cpu.len()
}
Expand Down
6 changes: 5 additions & 1 deletion evm_arithmetization/src/witness/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ use crate::memory::segments::Segment;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind};

#[inline(always)]
fn to_byte_checked(n: U256) -> u8 {
let res = n.byte(0);
assert_eq!(n, res.into());
res
}

#[inline(always)]
fn to_bits_le<F: RichField>(n: u8) -> [F; 8] {
let mut res = [F::ZERO; 8];
for (i, bit) in res.iter_mut().enumerate() {
Expand Down Expand Up @@ -76,7 +78,8 @@ pub(crate) fn fill_channel_with_value<F: RichField>(

/// Pushes without writing in memory. This happens in opcodes where a push
/// immediately follows a pop.
pub(crate) fn push_no_write<F: RichField>(state: &mut GenerationState<F>, val: U256) {
#[inline(always)]
pub(crate) const fn push_no_write<F: RichField>(state: &mut GenerationState<F>, val: U256) {
state.registers.stack_top = val;
state.registers.stack_len += 1;
}
Expand Down Expand Up @@ -135,6 +138,7 @@ pub(crate) fn mem_read_with_log<F: RichField>(
(val, op)
}

#[inline(always)]
pub(crate) fn mem_write_log<F: RichField>(
channel: MemoryChannel,
address: MemoryAddress,
Expand Down

0 comments on commit 08976ab

Please sign in to comment.