diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index bf9ffb8..4b62f0a 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -1,7 +1,5 @@ use petgraph::algo::toposort; use petgraph::graphmap::DiGraphMap; -use std::cell::RefCell; -use std::rc::Rc; use std::{borrow::Cow, marker::PhantomData}; use pool::{BufferPool, PooledBuffer}; @@ -15,11 +13,13 @@ use crate::{ storage::{BackendDevice, BackendStorage}, CompiledGraph, DType, GraphNode, Op, Result, }; -use rand::rng; use rand::Rng; use rand_distr::{Distribution, Normal}; mod pool; +// Concurrency primitives for dynamic DAG scheduler +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{mpsc, Arc, Mutex, RwLock}; pub struct CpuDevice; @@ -91,215 +91,344 @@ impl BackendDevice for CpuDevice { }) } - fn run_graph( + fn run_graph( &self, graph: &CompiledGraph, ) -> Result> { - { - // Create a shared buffer pool - let pool = Rc::new(RefCell::new(BufferPool::::new())); + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::{mpsc, Arc, Mutex, RwLock}; - #[allow(irrefutable_let_patterns)] - let CompiledGraph::Cpu { - order, - graph, - ghost: _, - } = graph - else { - unreachable!() - }; - - // Prepare storage for intermediate results - let mut results: Vec>> = Vec::with_capacity(graph.len()); - results.resize_with(graph.len(), || None); - let mut results_strides: Vec>> = Vec::with_capacity(graph.len()); - results_strides.resize_with(graph.len(), || None); + // Thread-safe buffer pool + let pool: Arc>> = Arc::new(Mutex::new(BufferPool::new())); - let mut rng = rng(); + // Extract the compiled node list + #[allow(irrefutable_let_patterns)] + let CompiledGraph::Cpu { + graph: node_graph, .. + } = graph + else { + unreachable!("Expected CPU compiled graph"); + }; + // Clone into an Arc for sharing + let node_graph = Arc::new(node_graph.clone()); + let n = node_graph.len(); - // Evaluate nodes in topological order - for idx in order.clone() { - let op = &graph[idx]; + // Prepare slots for results and strides + let results: Arc>>>> = + Arc::new((0..n).map(|_| RwLock::new(None)).collect()); + let results_strides: Arc>>>> = + Arc::new((0..n).map(|_| RwLock::new(None)).collect()); - let out_shape = &op.shape; - let out_elem_count: usize = out_shape.iter().product(); - - let computed = match &op.op { - Op::BinaryOp { - l_id, - r_id, - operator, - } => { - if l_id.is_inplace() { - let mut l_buf = results[l_id.get()].take().unwrap(); - let r_buf = results[r_id.get()].as_ref().unwrap(); - T::binary_simd_op_inplace_lhs(&mut l_buf, r_buf, *operator); - l_buf - } else if r_id.is_inplace() { - let mut r_buf = results[r_id.get()].take().unwrap(); - let l_buf = results[l_id.get()].as_ref().unwrap(); - T::binary_simd_op_inplace_rhs(l_buf, &mut r_buf, *operator); - r_buf - } else { - let l_buf = results[l_id.get()].as_ref().unwrap(); - let r_buf = results[r_id.get()].as_ref().unwrap(); - let mut out = pool.borrow_mut().get_buffer(out_elem_count); - T::binary_simd_op(l_buf, r_buf, &mut out, *operator); - PooledBuffer::new(out, pool.clone()) - } - } - Op::Fill { v } => { - let mut buf = pool.borrow_mut().get_empty_buffer(out_elem_count); - buf.extend(std::iter::repeat_n(*v, out_elem_count)); - PooledBuffer::new(buf, pool.clone()) - } - Op::Arange { start, step, stop } => { - let mut buf = pool.borrow_mut().get_empty_buffer(out_elem_count); - let mut x = start.to_f64(); - while x < stop.to_f64() { - buf.push(T::from_f64(x)); - x += step.to_f64(); - } - PooledBuffer::new(buf, pool.clone()) - } - Op::Rand => { - let mut buf = pool.borrow_mut().get_buffer(out_elem_count); - for elt in &mut buf { - *elt = T::from_f64(rng.random()); - } - PooledBuffer::new(buf, pool.clone()) - } - Op::Randn { mean, std } => { - let mean_f = mean.to_f64(); - let std_f = std.to_f64(); - let normal = Normal::new(mean_f, std_f).unwrap(); - let mut buf = pool.borrow_mut().get_buffer(out_elem_count); - for elt in &mut buf { - *elt = T::from_f64(normal.sample(&mut rng)); - } - PooledBuffer::new(buf, pool.clone()) - } - Op::UnaryOp { v_id, operator } => { - let buf = results[v_id.get()].as_ref().unwrap(); - let op_fn = operator.to_closure(); - let mut out = pool.borrow_mut().get_buffer(out_elem_count); - out.par_iter_mut() - .zip(&**buf) - .for_each(|(out, x): (&mut T, &T)| *out = op_fn(*x)); - PooledBuffer::new(out, pool.clone()) + // Build adjacency: children lists and indegree counts + let mut children = vec![Vec::new(); n]; + let indegree_vec = (0..n).map(|_| AtomicUsize::new(0)).collect::>(); + for node in node_graph.iter() { + let dst = node.id.get(); + match &node.op { + Op::BinaryOp { l_id, r_id, .. } => { + let p1 = l_id.get(); + let p2 = r_id.get(); + children[p1].push(dst); + children[p2].push(dst); + indegree_vec[dst].fetch_add(2, Ordering::SeqCst); + } + Op::UnaryOp { v_id, .. } => { + let p = v_id.get(); + children[p].push(dst); + indegree_vec[dst].fetch_add(1, Ordering::SeqCst); + } + Op::FusedMulAdd { a_id, b_id, c_id } => { + for &p in &[a_id.get(), b_id.get(), c_id.get()] { + children[p].push(dst); + indegree_vec[dst].fetch_add(1, Ordering::SeqCst); } - Op::FusedMulAdd { a_id, b_id, c_id } => { - if a_id.is_inplace() { - let mut a_buf = results[a_id.get()].take().unwrap(); - let b_buf = results[b_id.get()].as_ref().unwrap(); - let c_buf = results[c_id.get()].as_ref().unwrap(); - - T::fma_op_inplace_a(&mut a_buf, b_buf, c_buf); - a_buf - } else if b_id.is_inplace() { - let mut b_buf = results[b_id.get()].take().unwrap(); - let a_buf = results[a_id.get()].as_ref().unwrap(); - let c_buf = results[c_id.get()].as_ref().unwrap(); - - T::fma_op_inplace_b(a_buf, &mut b_buf, c_buf); - b_buf - } else if c_id.is_inplace() { - let mut c_buf = results[c_id.get()].take().unwrap(); - let a_buf = results[a_id.get()].as_ref().unwrap(); - let b_buf = results[b_id.get()].as_ref().unwrap(); - - T::fma_op_inplace_c(a_buf, b_buf, &mut c_buf); - c_buf - } else { - let a_buf = results[a_id.get()].as_ref().unwrap(); - let b_buf = results[b_id.get()].as_ref().unwrap(); - let c_buf = results[c_id.get()].as_ref().unwrap(); - - let mut out = pool.borrow_mut().get_buffer(out_elem_count); - T::fma_op(a_buf, b_buf, c_buf, &mut out); - PooledBuffer::new(out, pool.clone()) - } + } + Op::MatMul { + l_id, r_id, o_id, .. + } => { + let p1 = l_id.get(); + let p2 = r_id.get(); + children[p1].push(dst); + children[p2].push(dst); + indegree_vec[dst].fetch_add(2, Ordering::SeqCst); + if let Some(o) = o_id { + let p3 = o.get(); + children[p3].push(dst); + indegree_vec[dst].fetch_add(1, Ordering::SeqCst); } - // Matrix multiplication: multiply two 2D tensors A (m x k) and B (k x n) - Op::MatMul { - l_id, - r_id, - o_id, - k, - alpha, - beta, - } => { - // Determine output dimensions from shape S (must be 2D) - let shape = out_shape; - assert!(shape.len() == 3); - let b = shape[0]; - let m = shape[1]; - let n = shape[2]; + } + Op::Permute { v_id } => { + let p = v_id.get(); + children[p].push(dst); + indegree_vec[dst].fetch_add(1, Ordering::SeqCst); + } + _ => {} + } + } + let indegree = Arc::new(indegree_vec); + let children = Arc::new(children); - let (mut out, out_stride) = if let Some(o_id) = o_id { - if o_id.is_inplace() { - let out_strides = results_strides[o_id.get()].as_ref().unwrap(); - (results[o_id.get()].take().unwrap(), out_strides.clone()) - } else { - let o_buf = results[o_id.get()].as_ref().unwrap(); - let out_strides = results_strides[o_id.get()].as_ref().unwrap(); - ( - PooledBuffer::new((*o_buf).clone(), pool.clone()), - out_strides.clone(), - ) - } - } else { - ( - PooledBuffer::new( - pool.borrow_mut().get_buffer(b * m * n), - pool.clone(), - ), - contiguous_strides(&[b, m, n]), - ) - }; + // Channel to signal when the final node completes + let final_idx = n - 1; + let (tx, rx) = mpsc::channel(); - let a_buf = results[l_id.get()].as_ref().unwrap(); - let b_buf = results[r_id.get()].as_ref().unwrap(); + // Spawn initial tasks for nodes with zero indegree + for idx in 0..n { + if indegree[idx].load(Ordering::SeqCst) == 0 { + let pool = pool.clone(); + let node_graph = node_graph.clone(); + let results = results.clone(); + let results_strides = results_strides.clone(); + let indegree = indegree.clone(); + let children = children.clone(); + let tx = tx.clone(); + rayon::spawn(move || { + eval_node( + idx, + &node_graph, + &pool, + &results, + &results_strides, + &indegree, + &children, + final_idx, + tx, + ); + }); + } + } + // Drop the extra sender in main thread + drop(tx); - let a_strides = results_strides[l_id.get()].as_ref().unwrap(); - let b_strides = results_strides[r_id.get()].as_ref().unwrap(); + // Wait for the final node to complete + rx.recv() + .expect("Failed to receive completion of final node"); - T::launch_gemm( - a_buf, - a_strides, - b_buf, - b_strides, - b, - m, - n, - *k, - &mut out, - &out_stride, - *alpha, - *beta, - ); + // Extract and return the final result + let mut final_lock = results[final_idx].write().unwrap(); + let pooled = final_lock.take().expect("Final result missing"); + let output = pooled.into_inner(); + Ok(CpuStorage(output)) + } +} - out - } - Op::NoOp => unreachable!("NoOp should not be evaluated."), - Op::Permute { v_id } => { - if v_id.is_inplace() { - results[v_id.get()].take().unwrap() - } else { - let buf = results[v_id.get()].as_ref().unwrap(); - PooledBuffer::new((*buf).clone(), pool.clone()) - } - } - }; +/// Recursively evaluate a node, scheduling its children when their dependencies are ready. +#[allow(clippy::too_many_arguments)] +fn eval_node( + idx: usize, + node_graph: &Arc>>, + pool: &Arc>>, + results: &Arc>>>>, + results_strides: &Arc>>>>, + indegree: &Arc>, + children: &Arc>>, + final_idx: usize, + tx: mpsc::Sender<()>, +) { + // Prepare RNG for random ops + let mut rng = rand::rng(); + let node = &node_graph[idx]; + let out_shape = &node.shape; + let out_elem_count: usize = out_shape.iter().product(); - results[idx] = Some(computed); - results_strides[idx] = Some(op.strides.clone()); + // Compute this node's buffer + let computed: PooledBuffer = match &node.op { + Op::Fill { v } => { + let mut buf = pool.lock().unwrap().get_empty_buffer(out_elem_count); + buf.extend(std::iter::repeat_n(*v, out_elem_count)); + PooledBuffer::new(buf, pool.clone()) + } + Op::Arange { start, step, stop } => { + let mut buf = pool.lock().unwrap().get_empty_buffer(out_elem_count); + let mut x = start.to_f64(); + while x < stop.to_f64() { + buf.push(T::from_f64(x)); + x += step.to_f64(); } - - // Extract final result - let final_idx = graph.len() - 1; - let output = results[final_idx].take().unwrap().into_inner(); - Ok(CpuStorage(output)) + PooledBuffer::new(buf, pool.clone()) + } + Op::Rand => { + let mut buf = pool.lock().unwrap().get_buffer(out_elem_count); + for elt in &mut buf { + *elt = T::from_f64(rng.random()); + } + PooledBuffer::new(buf, pool.clone()) + } + Op::Randn { mean, std } => { + let mean_f = mean.to_f64(); + let std_f = std.to_f64(); + let normal = Normal::new(mean_f, std_f).unwrap(); + let mut buf = pool.lock().unwrap().get_buffer(out_elem_count); + for elt in &mut buf { + *elt = T::from_f64(normal.sample(&mut rng)); + } + PooledBuffer::new(buf, pool.clone()) + } + Op::UnaryOp { v_id, operator } => { + let src_guard = results[v_id.get()].read().unwrap(); + let src = src_guard.as_ref().unwrap(); + let op_fn = operator.to_closure(); + let mut out = pool.lock().unwrap().get_buffer(out_elem_count); + out.par_iter_mut() + .zip(&**src) + .for_each(|(o, x)| *o = op_fn(*x)); + PooledBuffer::new(out, pool.clone()) + } + Op::BinaryOp { + l_id, + r_id, + operator, + } => { + if l_id.is_inplace() { + let mut left = results[l_id.get()].write().unwrap().take().unwrap(); + let right_guard = results[r_id.get()].read().unwrap(); + let right = right_guard.as_ref().unwrap(); + T::binary_simd_op_inplace_lhs(&mut left, right, *operator); + left + } else if r_id.is_inplace() { + let mut right = results[r_id.get()].write().unwrap().take().unwrap(); + let left_guard = results[l_id.get()].read().unwrap(); + let left = left_guard.as_ref().unwrap(); + T::binary_simd_op_inplace_rhs(left, &mut right, *operator); + right + } else { + let left_guard = results[l_id.get()].read().unwrap(); + let left = left_guard.as_ref().unwrap(); + let right_guard = results[r_id.get()].read().unwrap(); + let right = right_guard.as_ref().unwrap(); + let mut out = pool.lock().unwrap().get_buffer(out_elem_count); + T::binary_simd_op(left, right, &mut out, *operator); + PooledBuffer::new(out, pool.clone()) + } + } + Op::FusedMulAdd { a_id, b_id, c_id } => { + if a_id.is_inplace() { + let mut a_buf = results[a_id.get()].write().unwrap().take().unwrap(); + let b_guard = results[b_id.get()].read().unwrap(); + let b_buf = b_guard.as_ref().unwrap(); + let c_guard = results[c_id.get()].read().unwrap(); + let c_buf = c_guard.as_ref().unwrap(); + T::fma_op_inplace_a(&mut a_buf, b_buf, c_buf); + a_buf + } else if b_id.is_inplace() { + let mut b_buf = results[b_id.get()].write().unwrap().take().unwrap(); + let a_guard = results[a_id.get()].read().unwrap(); + let a_buf = a_guard.as_ref().unwrap(); + let c_guard = results[c_id.get()].read().unwrap(); + let c_buf = c_guard.as_ref().unwrap(); + T::fma_op_inplace_b(a_buf, &mut b_buf, c_buf); + b_buf + } else if c_id.is_inplace() { + let mut c_buf = results[c_id.get()].write().unwrap().take().unwrap(); + let a_guard = results[a_id.get()].read().unwrap(); + let a_buf = a_guard.as_ref().unwrap(); + let b_guard = results[b_id.get()].read().unwrap(); + let b_buf = b_guard.as_ref().unwrap(); + T::fma_op_inplace_c(a_buf, b_buf, &mut c_buf); + c_buf + } else { + let a_guard = results[a_id.get()].read().unwrap(); + let a_buf = a_guard.as_ref().unwrap(); + let b_guard = results[b_id.get()].read().unwrap(); + let b_buf = b_guard.as_ref().unwrap(); + let c_guard = results[c_id.get()].read().unwrap(); + let c_buf = c_guard.as_ref().unwrap(); + let mut out = pool.lock().unwrap().get_buffer(out_elem_count); + T::fma_op(a_buf, b_buf, c_buf, &mut out); + PooledBuffer::new(out, pool.clone()) + } + } + Op::MatMul { + l_id, + r_id, + o_id, + k, + alpha, + beta, + } => { + let shape = &node.shape; + let b = shape[0]; + let m = shape[1]; + let n = shape[2]; + let (mut out_buf, out_stride) = if let Some(o) = o_id { + if o.is_inplace() { + let buf = results[o.get()].write().unwrap().take().unwrap(); + let st = results_strides[o.get()] + .read() + .unwrap() + .as_ref() + .unwrap() + .clone(); + (buf, st) + } else { + let buf_guard = results[o.get()].read().unwrap(); + let buf_clone = buf_guard.as_ref().unwrap(); + let st_guard = results_strides[o.get()].read().unwrap(); + let st = st_guard.as_ref().unwrap().clone(); + (PooledBuffer::new((*buf_clone).clone(), pool.clone()), st) + } + } else { + let st = contiguous_strides(&[b, m, n]); + let buf = pool.lock().unwrap().get_buffer(b * m * n); + (PooledBuffer::new(buf, pool.clone()), st) + }; + let a_guard = results[l_id.get()].read().unwrap(); + let a_buf = a_guard.as_ref().unwrap(); + let b_guard = results[r_id.get()].read().unwrap(); + let b_buf = b_guard.as_ref().unwrap(); + let a_str_guard = results_strides[l_id.get()].read().unwrap(); + let a_str = a_str_guard.as_ref().unwrap(); + let b_str_guard = results_strides[r_id.get()].read().unwrap(); + let b_str = b_str_guard.as_ref().unwrap(); + T::launch_gemm( + a_buf, + a_str, + b_buf, + b_str, + b, + m, + n, + *k, + &mut out_buf, + &out_stride, + *alpha, + *beta, + ); + out_buf + } + Op::Permute { v_id } => { + if v_id.is_inplace() { + results[v_id.get()].write().unwrap().take().unwrap() + } else { + let buf_guard = results[v_id.get()].read().unwrap(); + let buf = buf_guard.as_ref().unwrap(); + PooledBuffer::new((*buf).clone(), pool.clone()) + } + } + Op::NoOp => panic!("NoOp should not be evaluated"), + }; + // store result and strides + *results[idx].write().unwrap() = Some(computed); + *results_strides[idx].write().unwrap() = Some(node.strides.clone()); + // signal final + if idx == final_idx { + let _ = tx.send(()); + } + // schedule children + for &child in &children[idx] { + if indegree[child].fetch_sub(1, Ordering::SeqCst) == 1 { + let pool2 = pool.clone(); + let ng2 = node_graph.clone(); + let res2 = results.clone(); + let rs2 = results_strides.clone(); + let indeg2 = indegree.clone(); + let ch2 = children.clone(); + let tx2 = tx.clone(); + rayon::spawn(move || { + eval_node( + child, &ng2, &pool2, &res2, &rs2, &indeg2, &ch2, final_idx, tx2, + ); + }); } } } diff --git a/constensor-core/src/cpu_storage/pool.rs b/constensor-core/src/cpu_storage/pool.rs index 5db02ec..84154b2 100644 --- a/constensor-core/src/cpu_storage/pool.rs +++ b/constensor-core/src/cpu_storage/pool.rs @@ -1,5 +1,5 @@ use std::mem; -use std::{cell::RefCell, rc::Rc}; +use std::sync::{Arc, Mutex}; use crate::DType; @@ -30,8 +30,8 @@ pub struct BufferPool { pub metrics: PoolMetrics, } -/// Shared reference to a BufferPool for automatic recycling. -pub type SharedPool = Rc>>; +/// Shared, thread-safe reference to a BufferPool for automatic recycling. +pub type SharedPool = Arc>>; #[derive(Debug)] /// Wrapper around Vec that returns its buffer to the pool on drop. @@ -74,7 +74,9 @@ impl Drop for PooledBuffer { fn drop(&mut self) { if let Some(pool) = self.pool.take() { let buf = std::mem::take(&mut self.buf); - pool.borrow_mut().recycle_buffer(buf); + // Return the buffer to the pool + let mut pool_guard = pool.lock().unwrap(); + pool_guard.recycle_buffer(buf); } } } diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index be6966a..6a9bd68 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -1,5 +1,5 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; use std::{ - cell::Cell, collections::HashMap, env, fmt::Display, @@ -8,7 +8,6 @@ use std::{ marker::PhantomData, path::Path, process::Command, - rc::Rc, sync::{Arc, RwLock, RwLockReadGuard}, }; @@ -779,11 +778,11 @@ pub enum Op { NoOp, } -#[derive(Clone, PartialEq, Debug, Eq)] +#[derive(Clone, Debug)] /// Graph tensor IDs can be cloned. pub enum GraphTensorId { - OutOfPlace(Rc>), - InPlace(Rc>), + OutOfPlace(Arc), + InPlace(Arc), } impl Hash for GraphTensorId { @@ -794,35 +793,39 @@ impl Hash for GraphTensorId { impl GraphTensorId { pub fn out_of_place(value: usize) -> Self { - Self::OutOfPlace(Rc::new(Cell::new(value))) + Self::OutOfPlace(Arc::new(AtomicUsize::new(value))) } pub fn inplace(value: usize) -> Self { - Self::InPlace(Rc::new(Cell::new(value))) + Self::InPlace(Arc::new(AtomicUsize::new(value))) } pub fn to_inplace(&self) -> Self { match self { - Self::OutOfPlace(x) | Self::InPlace(x) => Self::inplace(x.get()), + Self::OutOfPlace(x) | Self::InPlace(x) => Self::inplace(x.load(Ordering::SeqCst)), } } pub fn to_inplace_if(&self, predicate: bool) -> Self { match self { - Self::OutOfPlace(x) | Self::InPlace(x) if predicate => Self::inplace(x.get()), + Self::OutOfPlace(x) | Self::InPlace(x) if predicate => { + Self::inplace(x.load(Ordering::SeqCst)) + } _ => self.clone(), } } pub fn get(&self) -> usize { match self { - GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.get(), + GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.load(Ordering::SeqCst), } } pub fn set(&self, value: usize) { match self { - GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.set(value), + GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => { + x.store(value, Ordering::SeqCst); + } } } @@ -830,3 +833,12 @@ impl GraphTensorId { matches!(self, Self::InPlace(_)) } } + +// Manually implement equality by comparing the numeric IDs and in‐place flag: +impl PartialEq for GraphTensorId { + fn eq(&self, other: &Self) -> bool { + self.get() == other.get() && self.is_inplace() == other.is_inplace() + } +} + +impl Eq for GraphTensorId {}