Skip to content

Commit

Permalink
Row Data Stored On The Heap
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Feb 23, 2025
1 parent 81d1fe3 commit c30b611
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions crates/air_utils/src/trace/row_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, Uninde
use rayon::prelude::*;
use stwo_prover::core::backend::simd::m31::PackedM31;

pub type MutRow<'trace, const N: usize> = [&'trace mut PackedM31; N];
pub type MutRow<'trace, const N: usize> = Box<[&'trace mut PackedM31; N]>;

/// An iterator over mutable references to the rows of a [`super::component_trace::ComponentTrace`].
// TODO(Ohad): Iterating over single rows is not optimal, figure out optimal chunk size when using
// this iterator.
pub struct RowIterMut<'trace, const N: usize> {
v: [*mut [PackedM31]; N],
v: Box<[*mut [PackedM31]; N]>,
phantom: PhantomData<&'trace ()>,
}
impl<'trace, const N: usize> RowIterMut<'trace, N> {
pub fn new(slice: [&'trace mut [PackedM31]; N]) -> Self {
Self {
v: slice.map(|s| s as *mut _),
v: Box::new(slice.map(|s| s as *mut _)),
phantom: PhantomData,
}
}
Expand All @@ -34,7 +34,7 @@ impl<'trace, const N: usize> Iterator for RowIterMut<'trace, N> {
self.v[i] = tail;
&mut (*head)[0]
});
Some(item)
Some(Box::new(item))
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand All @@ -54,12 +54,12 @@ impl<const N: usize> DoubleEndedIterator for RowIterMut<'_, N> {
self.v[i] = head;
&mut (*tail)[0]
});
Some(item)
Some(Box::new(item))
}
}

struct RowProducer<'trace, const N: usize> {
data: [&'trace mut [PackedM31]; N],
data: Box<[&'trace mut [PackedM31]; N]>,
}
impl<'trace, const N: usize> Producer for RowProducer<'trace, N> {
type Item = MutRow<'trace, N>;
Expand All @@ -72,14 +72,21 @@ impl<'trace, const N: usize> Producer for RowProducer<'trace, N> {
left[i] = lhs;
right[i] = rhs;
}
(RowProducer { data: left }, RowProducer { data: right })
(
RowProducer {
data: Box::new(left),
},
RowProducer {
data: Box::new(right),
},
)
}

type IntoIter = RowIterMut<'trace, N>;

fn into_iter(self) -> Self::IntoIter {
RowIterMut {
v: self.data.map(|s| s as *mut _),
v: Box::new(self.data.map(|s| s as *mut _)),
phantom: PhantomData,
}
}
Expand All @@ -90,11 +97,13 @@ impl<'trace, const N: usize> Producer for RowProducer<'trace, N> {
/// array of columns, hence iterating over rows is not trivial. Iteration is done by iterating over
/// `N` columns in parallel.
pub struct ParRowIterMut<'trace, const N: usize> {
data: [&'trace mut [PackedM31]; N],
data: Box<[&'trace mut [PackedM31]; N]>,
}
impl<'trace, const N: usize> ParRowIterMut<'trace, N> {
pub(super) fn new(data: [&'trace mut [PackedM31]; N]) -> Self {
Self { data }
Self {
data: Box::new(data),
}
}
}
impl<'trace, const N: usize> ParallelIterator for ParRowIterMut<'trace, N> {
Expand Down

0 comments on commit c30b611

Please sign in to comment.