From e427567598c24dd76e36ff69806f58836a85b45b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 13:38:08 -0400 Subject: [PATCH 01/11] First steps --- constensor-core/src/cpu_storage/mod.rs | 4 + constensor-core/src/graph.rs | 18 +-- constensor-core/src/tensor/concretetensor.rs | 75 +++++++---- constensor-core/src/tensor/graphtensor.rs | 125 ++++++++++++++----- constensor-core/src/tensor/mod.rs | 17 +++ 5 files changed, 176 insertions(+), 63 deletions(-) diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index d94e9ce..97ba700 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -9,6 +9,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelI use crate::device::Dev; use crate::storage::Storage; +use crate::tensor::is_contiguous_strides; use crate::Shape; use crate::{ storage::{BackendDevice, BackendStorage}, @@ -112,6 +113,9 @@ impl BackendDevice for CpuDevice { let out_shape = &op.shape; let out_elem_count: usize = out_shape.iter().product(); + let strides = &op.strides; + let is_contiguous = is_contiguous_strides(strides, out_shape); + let computed = match &op.op { Op::BinaryOp { l_id, diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index 6c7d2c7..e692452 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -21,6 +21,7 @@ use petgraph::{dot::Dot, graph::NodeIndex}; pub struct GraphNode { pub op: Op, pub shape: Vec, + pub strides: Vec, } #[derive(Clone)] @@ -44,10 +45,11 @@ impl Graph { } /// Append an operation to the graph - pub(crate) fn add_op(&self, op: Op) { + pub(crate) fn add_op(&self, op: Op, strides: &[usize]) { self.data.write().unwrap().push(GraphNode { op, shape: S::shape(), + strides: strides.to_vec(), }); } @@ -232,7 +234,7 @@ impl Graph { let v = operator.as_closure()(*v1, *v2); new_ops[i] = GraphNode { op: Op::Fill { v }, - shape: node.shape.clone(), + ..node.clone() }; } } @@ -244,7 +246,7 @@ impl Graph { let v = operator.to_closure()(*v0); new_ops[i] = GraphNode { op: Op::Fill { v }, - shape: node.shape.clone(), + ..node.clone() }; } } @@ -285,11 +287,11 @@ impl Graph { b_id: b_id.clone(), c_id: rhs_add.clone(), }, - shape: x.shape.clone(), + ..x.clone() }; new_ops[x_id] = GraphNode { op: Op::NoOp, - shape: x.shape.clone(), + ..x.clone() }; // Look for ops which actually use this one @@ -404,7 +406,7 @@ impl Graph { r_id: r_id.clone().to_inplace_if(&target == r_id), operator: *operator, }, - shape: op.shape.clone(), + ..op.clone() }; } } @@ -437,7 +439,7 @@ impl Graph { b_id: b_id.clone().to_inplace_if(&out == b_id), c_id: c_id.clone().to_inplace_if(&out == c_id), }, - shape: op.shape.clone(), + ..op.clone() }; } } @@ -474,7 +476,7 @@ impl Graph { alpha: *alpha, beta: *beta, }, - shape: op.shape.clone(), + ..op.clone() }; } } diff --git a/constensor-core/src/tensor/concretetensor.rs b/constensor-core/src/tensor/concretetensor.rs index 3b07361..b925b36 100644 --- a/constensor-core/src/tensor/concretetensor.rs +++ b/constensor-core/src/tensor/concretetensor.rs @@ -9,9 +9,12 @@ use crate::device::Cuda; use std::{borrow::Cow, marker::PhantomData, ops::Deref, sync::Arc}; +use super::contiguous_strides; + #[derive(Clone)] pub struct Tensor_ { storage: Arc>, + strides: Vec, _ghost: PhantomData<(S, T, D)>, } @@ -28,11 +31,27 @@ impl Deref for Tensor { } } +/// Create a Tensor from storage with its default (contiguous) strides. pub(crate) fn from_storage( storage: Arc>, ) -> Tensor { + let shape = S::shape(); + let strides = contiguous_strides(&shape); Tensor(Arc::new(Tensor_ { storage, + strides, + _ghost: PhantomData, + })) +} + +/// Create a Tensor from storage with explicit strides (for views/transposes). +fn from_storage_strided( + storage: Arc>, + strides: Vec, +) -> Tensor { + Tensor(Arc::new(Tensor_ { + storage, + strides, _ghost: PhantomData, })) } @@ -48,13 +67,16 @@ macro_rules! tensor_api { } impl Tensor, T, $device> { - /// Get data for a matrix. + /// Get data for a matrix, respecting strides (supports views/transposes). pub fn data(&self) -> Result>>> { let data = self.storage.to_cpu_storage()?; - let mut rows = Vec::new(); + let mut rows = Vec::with_capacity(A); for i in 0..A { - let row = (0..B).map(|j| data.as_ref().0[i * A + j]).collect(); - rows.push(row) + let base = i * self.strides[0]; + let row = (0..B) + .map(|j| data.as_ref().0[base + j * self.strides[1]]) + .collect(); + rows.push(row); } Ok(Cow::Owned(rows)) } @@ -63,15 +85,19 @@ macro_rules! tensor_api { impl Tensor, T, $device> { - /// Get data for a 3 dimensional tensor. + /// Get data for a 3 dimensional tensor, respecting strides (supports views/transposes). pub fn data(&self) -> Result>>>> { let data = self.storage.to_cpu_storage()?; - let mut top_rows = Vec::new(); + let mut top_rows = Vec::with_capacity(A); for i in 0..A { - let mut rows = Vec::new(); + let off_i = i * self.strides[0]; + let mut rows = Vec::with_capacity(B); for j in 0..B { - let row = (0..C).map(|k| data.as_ref().0[i * A + j * B + k]).collect(); - rows.push(row) + let off_j = off_i + j * self.strides[1]; + let row = (0..C) + .map(|k| data.as_ref().0[off_j + k * self.strides[2]]) + .collect(); + rows.push(row); } top_rows.push(rows); } @@ -94,19 +120,22 @@ impl Tensor { } } -/*macro_rules! binary_op { - ($trait:ident, $fn:ident) => { - impl $trait for Tensor { - type Output = Result>; - fn $fn(self, rhs: Self) -> Self::Output { - Ok(Self::fromTensor_(self.inner.$fn(&rhs.inner)?)) - } - } - }; +impl Tensor, T, D> { + /// Return a view of this matrix with dimensions transposed (A x B -> B x A). + pub fn t(&self) -> Tensor, T, D> { + // swap strides for first two dimensions + let mut new_strides = self.strides.clone(); + new_strides.swap(0, 1); + from_storage_strided::, T, D>(Arc::clone(&self.storage), new_strides) + } } -binary_op!(Add, add); -binary_op!(Mul, mul); -binary_op!(Sub, sub); -binary_op!(Div, div); -*/ +impl Tensor, T, D> { + /// Return a view of this tensor with last two reversed axes (A x B x C -> A x C x B). + pub fn t(&self) -> Tensor, T, D> { + // swap strides for last two dimensions + let mut new_strides = self.strides.clone(); + new_strides.swap(1, 2); + from_storage_strided::, T, D>(Arc::clone(&self.storage), new_strides) + } +} diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs index 614a22f..fcfeba3 100644 --- a/constensor-core/src/tensor/graphtensor.rs +++ b/constensor-core/src/tensor/graphtensor.rs @@ -7,15 +7,18 @@ use std::{ use crate::{ device::Dev, graph::{BinaryOpType, Graph, GraphTensorId, Op, UnaryOpType}, - DType, Shape, R1, R3, + DType, Shape, R1, R2, R3, }; +use super::contiguous_strides; + /// A tensor representing an intermediary result of a graph. Performing operations /// on this tensor will not cause any computations. #[derive(Clone)] pub struct GraphTensor { id: GraphTensorId, graph: Arc>>, + strides: Vec, _ghost: PhantomData<(S, T, D)>, } @@ -28,20 +31,21 @@ impl self, rhs: GraphTensor, T, D>, ) -> GraphTensor, T, D> { - self.graph - .write() - .unwrap() - .add_op::>(Op::MatMul { + self.graph.write().unwrap().add_op::>( + Op::MatMul { l_id: self.id(), r_id: rhs.id(), o_id: None, k: K, alpha: T::ZERO, beta: T::ONE, - }); + }, + &self.strides, + ); GraphTensor { id: self.graph.write().unwrap().next_id(), graph: self.graph.clone(), + strides: self.strides.clone(), _ghost: PhantomData, } } @@ -56,20 +60,21 @@ impl alpha: T, beta: T, ) -> GraphTensor, T, D> { - self.graph - .write() - .unwrap() - .add_op::>(Op::MatMul { + self.graph.write().unwrap().add_op::>( + Op::MatMul { l_id: self.id(), r_id: rhs.id(), o_id: Some(out.id()), k: K, alpha, beta, - }); + }, + &self.strides, + ); GraphTensor { id: self.graph.write().unwrap().next_id(), graph: self.graph.clone(), + strides: self.strides.clone(), _ghost: PhantomData, } } @@ -80,10 +85,12 @@ impl GraphTensor { /// Create a tensor filled with some value. pub fn fill(graph: &mut Graph, v: T) -> Self { let id = graph.next_id(); - graph.add_op::(Op::Fill { v }); + let strides = contiguous_strides(&S::shape()); + graph.add_op::(Op::Fill { v }, &strides); Self { id, graph: Arc::new(RwLock::new(graph.clone())), + strides, _ghost: PhantomData, } } @@ -103,24 +110,31 @@ impl GraphTensor { #[must_use] /// Elementwise unary square root. pub fn sqrt(self) -> GraphTensor { - self.graph.write().unwrap().add_op::(Op::UnaryOp { - v_id: self.id(), - operator: UnaryOpType::Sqrt, - }); + self.graph.write().unwrap().add_op::( + Op::UnaryOp { + v_id: self.id(), + operator: UnaryOpType::Sqrt, + }, + &self.strides, + ); Self { id: self.graph.write().unwrap().next_id(), graph: self.graph.clone(), + strides: self.strides.clone(), _ghost: PhantomData, } } + #[must_use] /// Create a tensor filled with uniform random values in [0,1). pub fn rand(graph: &mut Graph) -> Self { let id = graph.next_id(); - graph.add_op::(Op::Rand); + let strides = contiguous_strides(&S::shape()); + graph.add_op::(Op::Rand, &strides); GraphTensor { id, graph: Arc::new(RwLock::new(graph.clone())), + strides, _ghost: PhantomData, } } @@ -129,10 +143,12 @@ impl GraphTensor { /// Create a tensor filled with normally distributed random values (mean, std). pub fn randn(graph: &mut Graph, mean: T, std: T) -> Self { let id = graph.next_id(); - graph.add_op::(Op::Randn { mean, std }); + let strides = contiguous_strides(&S::shape()); + graph.add_op::(Op::Randn { mean, std }, &strides); GraphTensor { id, graph: Arc::new(RwLock::new(graph.clone())), + strides, _ghost: PhantomData, } } @@ -156,14 +172,51 @@ impl GraphTensor, T, D> { pub fn arange(graph: &mut Graph, start: T, stop: T) -> Self { let id = graph.next_id(); let step = (stop.to_f64() - start.to_f64()) / (A as f64); - graph.add_op::>(Op::Arange { - start, - step: T::from_f64(step), - stop, - }); + let strides = contiguous_strides(&[A]); + graph.add_op::>( + Op::Arange { + start, + step: T::from_f64(step), + stop, + }, + &strides, + ); Self { id, graph: Arc::new(RwLock::new(graph.clone())), + strides, + _ghost: PhantomData, + } + } +} + +impl GraphTensor, T, D> { + /// Return a view of this matrix with dimensions transposed (A x B -> B x A). + pub fn t(&self) -> GraphTensor, T, D> { + // swap strides for first two dimensions + let mut new_strides = self.strides.clone(); + new_strides.swap(0, 1); + GraphTensor { + id: self.graph.write().unwrap().next_id(), + graph: self.graph.clone(), + strides: new_strides, + _ghost: PhantomData, + } + } +} + +impl + GraphTensor, T, D> +{ + /// Return a view of this tensor with last two reversed axes (A x B x C -> A x C x B). + pub fn t(&self) -> GraphTensor, T, D> { + // swap strides for last two dimensions + let mut new_strides = self.strides.clone(); + new_strides.swap(1, 2); + GraphTensor { + id: self.graph.write().unwrap().next_id(), + graph: self.graph.clone(), + strides: new_strides, _ghost: PhantomData, } } @@ -175,14 +228,18 @@ macro_rules! graphtensor_binop { type Output = GraphTensor; /// Add an elementwise operation to the graph. fn $fn_name(self, rhs: Self) -> Self::Output { - self.graph.write().unwrap().add_op::(Op::BinaryOp { - l_id: self.id(), - r_id: rhs.id(), - operator: BinaryOpType::$trait, - }); + self.graph.write().unwrap().add_op::( + Op::BinaryOp { + l_id: self.id(), + r_id: rhs.id(), + operator: BinaryOpType::$trait, + }, + &self.strides, + ); Self { id: self.graph.write().unwrap().next_id(), graph: self.graph.clone(), + strides: self.strides.clone(), _ghost: PhantomData, } } @@ -199,13 +256,17 @@ impl, D: Dev> Neg for GraphTensor type Output = GraphTensor; /// Add an elementwise addition operation to the graph. fn neg(self) -> Self::Output { - self.graph.write().unwrap().add_op::(Op::UnaryOp { - v_id: self.id(), - operator: UnaryOpType::Neg, - }); + self.graph.write().unwrap().add_op::( + Op::UnaryOp { + v_id: self.id(), + operator: UnaryOpType::Neg, + }, + &self.strides, + ); Self { id: self.graph.write().unwrap().next_id(), graph: self.graph.clone(), + strides: self.strides.clone(), _ghost: PhantomData, } } diff --git a/constensor-core/src/tensor/mod.rs b/constensor-core/src/tensor/mod.rs index 85d020a..378cb7c 100644 --- a/constensor-core/src/tensor/mod.rs +++ b/constensor-core/src/tensor/mod.rs @@ -3,3 +3,20 @@ pub mod graphtensor; pub use concretetensor::Tensor; pub use graphtensor::GraphTensor; + +pub(crate) fn is_contiguous_strides(strides: &[usize], shape: &[usize]) -> bool { + strides == &contiguous_strides(shape) +} + +/// Compute default (contiguous) strides for a tensor of given shape. +pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { + let mut strides = Vec::with_capacity(shape.len()); + let mut acc = 1; + // Iterate dims in reverse to accumulate products + for dim in shape.iter().rev() { + strides.push(acc); + acc *= *dim; + } + strides.reverse(); + strides +} From 957e3343f5048fa024f942be4781255fff6959e1 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 14:10:01 -0400 Subject: [PATCH 02/11] Add stride for gemm, not simd yet --- constensor-core/examples/matmul/main.rs | 19 ++++---- constensor-core/src/cpu_storage/mod.rs | 55 ++++++++++++++++++----- constensor-core/src/dtype/gemm.rs | 31 ++++++++++--- constensor-core/src/graph.rs | 15 ++++++- constensor-core/src/tensor/graphtensor.rs | 2 +- constensor-core/src/tensor/mod.rs | 4 -- 6 files changed, 94 insertions(+), 32 deletions(-) diff --git a/constensor-core/examples/matmul/main.rs b/constensor-core/examples/matmul/main.rs index 8dceab2..e7ec331 100644 --- a/constensor-core/examples/matmul/main.rs +++ b/constensor-core/examples/matmul/main.rs @@ -7,13 +7,13 @@ fn bench, T, Cpu>::ones(&mut graph); - let b = GraphTensor::, T, Cpu>::ones(&mut graph); - let o = GraphTensor::, T, Cpu>::ones(&mut graph); + let a = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(1.)); + let b = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(2.)).t(); + let o = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(3.)); let _c = a.matmul_axpby(b, o, alpha, beta); graph.optimize(); @@ -22,7 +22,8 @@ fn bench("f32", 1.0, 1.0); - bench::("i32", 1, 1); + // bench::("i32", 1, 1); } diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index 97ba700..df6ae58 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -9,7 +9,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelI use crate::device::Dev; use crate::storage::Storage; -use crate::tensor::is_contiguous_strides; +use crate::tensor::contiguous_strides; use crate::Shape; use crate::{ storage::{BackendDevice, BackendStorage}, @@ -63,15 +63,21 @@ impl BackendDevice for CpuDevice { dep_graph.add_edge(b_id.get(), idx, ()); dep_graph.add_edge(c_id.get(), idx, ()); } - Op::MatMul { l_id, r_id, .. } => { + Op::MatMul { + l_id, r_id, o_id, .. + } => { dep_graph.add_edge(l_id.get(), idx, ()); dep_graph.add_edge(r_id.get(), idx, ()); + if let Some(o_id) = o_id { + dep_graph.add_edge(o_id.get(), idx, ()); + } } // NoOp, Fill/Arange, Rand/Randn don’t create incoming edges Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {} } } + dbg!(&dep_graph); // Compute topological order let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); @@ -103,6 +109,8 @@ impl BackendDevice for CpuDevice { // Prepare storage for intermediate results let mut results: Vec>> = Vec::with_capacity(graph.len()); results.resize_with(graph.len(), || None); + let mut results_strides: Vec>> = Vec::with_capacity(graph.len()); + results_strides.resize_with(graph.len(), || None); let mut rng = rng(); @@ -113,9 +121,6 @@ impl BackendDevice for CpuDevice { let out_shape = &op.shape; let out_elem_count: usize = out_shape.iter().product(); - let strides = &op.strides; - let is_contiguous = is_contiguous_strides(strides, out_shape); - let computed = match &op.op { Op::BinaryOp { l_id, @@ -228,27 +233,57 @@ impl BackendDevice for CpuDevice { let m = shape[1]; let n = shape[2]; - let mut out = if let Some(o_id) = o_id { + let (mut out, out_stride) = if let Some(o_id) = o_id { if o_id.is_inplace() { - results[o_id.get()].take().unwrap() + let out_strides = results_strides[o_id.get()].as_ref().unwrap(); + (results[o_id.get()].take().unwrap(), out_strides.clone()) } else { let o_buf = results[o_id.get()].as_ref().unwrap(); - PooledBuffer::new((*o_buf).clone(), pool.clone()) + let out_strides = results_strides[o_id.get()].as_ref().unwrap(); + ( + PooledBuffer::new((*o_buf).clone(), pool.clone()), + out_strides.clone(), + ) } } else { - PooledBuffer::new(pool.borrow_mut().get_buffer(m * n), pool.clone()) + ( + PooledBuffer::new( + pool.borrow_mut().get_buffer(b * m * n), + pool.clone(), + ), + contiguous_strides(&[b, m, n]), + ) }; let a_buf = results[l_id.get()].as_ref().unwrap(); let b_buf = results[r_id.get()].as_ref().unwrap(); - T::launch_gemm(a_buf, b_buf, b, m, n, *k, &mut out, *alpha, *beta); + let a_strides = results_strides[l_id.get()].as_ref().unwrap(); + let b_strides = results_strides[r_id.get()].as_ref().unwrap(); + + T::launch_gemm( + a_buf, + a_strides, + b_buf, + b_strides, + b, + m, + n, + *k, + &mut out, + &out_stride, + *alpha, + *beta, + ); out } Op::NoOp => unreachable!("NoOp should not be evaluated."), }; + results[idx] = Some(computed); + dbg!(&op.strides); + results_strides[idx] = Some(op.strides.clone()); } // Extract final result diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs index b8870fb..44b7e86 100644 --- a/constensor-core/src/dtype/gemm.rs +++ b/constensor-core/src/dtype/gemm.rs @@ -13,12 +13,15 @@ pub trait GemmDispatch { // Matrix multiplication: (B x M x K) * (B x K x N) = (B x M x N) fn launch_gemm( lhs: &[Self], + lhs_stride: &[usize], rhs: &[Self], + rhs_stride: &[usize], b: usize, m: usize, n: usize, k: usize, out: &mut Vec, + out_stride: &[usize], alpha: Self, beta: Self, ) where @@ -151,12 +154,15 @@ macro_rules! instantiate_gemm { impl GemmDispatch for $rt { fn launch_gemm( lhs: &[Self], + lhs_stride: &[usize], rhs: &[Self], + rhs_stride: &[usize], b: usize, m: usize, n: usize, k: usize, out: &mut Vec, + out_stride: &[usize], alpha: Self, beta: Self, ) where @@ -169,15 +175,22 @@ macro_rules! instantiate_gemm { Parallelism::None }; + debug_assert_eq!(lhs.len(), b * m * k); + debug_assert_eq!(lhs_stride.len(), 3); + debug_assert_eq!(rhs.len(), b * k * n); + debug_assert_eq!(rhs_stride.len(), 3); + debug_assert_eq!(out.len(), b * m * n); + debug_assert_eq!(out_stride.len(), 3); + // cs = stride[-1], rs = stride[-2] - let dst_cs = 1; - let dst_rs = n; + let dst_cs = out_stride[2]; + let dst_rs = out_stride[1]; - let lhs_cs = 1; - let lhs_rs = k; + let lhs_cs = lhs_stride[2]; + let lhs_rs = lhs_stride[1]; - let rhs_cs = 1; - let rhs_rs = n; + let rhs_cs = rhs_stride[2]; + let rhs_rs = rhs_stride[1]; let read_dst = alpha != $zero; @@ -220,12 +233,15 @@ macro_rules! instantiate_gemm { impl GemmDispatch for $rt { fn launch_gemm( lhs: &[Self], + lhs_stride: &[usize], rhs: &[Self], + rhs_stride: &[usize], b: usize, m: usize, n: usize, k: usize, out: &mut Vec, + out_stride: &[usize], alpha: Self, beta: Self, ) where @@ -238,8 +254,11 @@ macro_rules! instantiate_gemm { let rem = n % BLOCK_SIZE; debug_assert_eq!(lhs.len(), b * m * k); + debug_assert_eq!(lhs_stride.len(), 3); debug_assert_eq!(rhs.len(), b * k * n); + debug_assert_eq!(rhs_stride.len(), 3); debug_assert_eq!(out.len(), b * m * n); + debug_assert_eq!(out_stride.len(), 3); for batch in 0..b { // Compute base pointers once per batch diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index e692452..fd0ef49 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -315,7 +315,12 @@ impl Graph { } => { vec![a_id, b_id, c_id] } - Op::MatMul { l_id, r_id, .. } => vec![l_id, r_id], + Op::MatMul { + l_id, r_id, o_id, .. + } => o_id + .as_ref() + .map(|o| vec![l_id, r_id, o]) + .unwrap_or(vec![l_id, r_id]), Op::NoOp => vec![], }; @@ -365,9 +370,14 @@ impl Graph { *usage.entry(b_id.clone()).or_default() += 1; *usage.entry(c_id.clone()).or_default() += 1; } - Op::MatMul { l_id, r_id, .. } => { + Op::MatMul { + l_id, r_id, o_id, .. + } => { *usage.entry(l_id.clone()).or_default() += 1; *usage.entry(r_id.clone()).or_default() += 1; + if let Some(o_id) = o_id { + *usage.entry(o_id.clone()).or_default() += 1; + } } // No input usage for these ops Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {} @@ -588,6 +598,7 @@ impl Graph { /// - Inplace matrix-multiplication when safe /// - Dead code removal pub fn optimize(&mut self) { + return; // Constant folding first self.optimize_const(); // Fuse mul-add into FMA diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs index fcfeba3..be9e3fb 100644 --- a/constensor-core/src/tensor/graphtensor.rs +++ b/constensor-core/src/tensor/graphtensor.rs @@ -209,7 +209,7 @@ impl GraphTensor, T, D> { /// Return a view of this tensor with last two reversed axes (A x B x C -> A x C x B). - pub fn t(&self) -> GraphTensor, T, D> { + pub fn t(&self) -> GraphTensor, T, D> { // swap strides for last two dimensions let mut new_strides = self.strides.clone(); new_strides.swap(1, 2); diff --git a/constensor-core/src/tensor/mod.rs b/constensor-core/src/tensor/mod.rs index 378cb7c..4ce3019 100644 --- a/constensor-core/src/tensor/mod.rs +++ b/constensor-core/src/tensor/mod.rs @@ -4,10 +4,6 @@ pub mod graphtensor; pub use concretetensor::Tensor; pub use graphtensor::GraphTensor; -pub(crate) fn is_contiguous_strides(strides: &[usize], shape: &[usize]) -> bool { - strides == &contiguous_strides(shape) -} - /// Compute default (contiguous) strides for a tensor of given shape. pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { let mut strides = Vec::with_capacity(shape.len()); From 7fa3f83fa307a7810db3351d4198b36dd8eda97f Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 14:44:40 -0400 Subject: [PATCH 03/11] Add a permute node --- constensor-core/examples/matmul/main.rs | 1 + constensor-core/src/cpu_storage/mod.rs | 20 ++++++--- constensor-core/src/graph.rs | 32 ++++++++++++-- constensor-core/src/tensor/graphtensor.rs | 51 ++++++++++++++++++----- 4 files changed, 86 insertions(+), 18 deletions(-) diff --git a/constensor-core/examples/matmul/main.rs b/constensor-core/examples/matmul/main.rs index e7ec331..25d218f 100644 --- a/constensor-core/examples/matmul/main.rs +++ b/constensor-core/examples/matmul/main.rs @@ -13,6 +13,7 @@ fn bench, T, Cpu>::fill(&mut graph, T::from_f64(1.)); let b = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(2.)).t(); + // let b = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(2.)); let o = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(3.)); let _c = a.matmul_axpby(b, o, alpha, beta); diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index df6ae58..7a87303 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -45,11 +45,12 @@ impl BackendDevice for CpuDevice { ) -> Result> { // Build a dependency graph of tensor indices let mut dep_graph = DiGraphMap::::new(); - for idx in 0..graph.len() { - dep_graph.add_node(idx); + for id in graph.iter().map(|node| node.id.get()) { + dep_graph.add_node(id); } - for (idx, node) in graph.iter().enumerate() { + for node in graph.iter() { + let idx = node.id.get(); match &node.op { Op::BinaryOp { l_id, r_id, .. } => { dep_graph.add_edge(l_id.get(), idx, ()); @@ -72,12 +73,14 @@ impl BackendDevice for CpuDevice { dep_graph.add_edge(o_id.get(), idx, ()); } } + Op::Permute { v_id } => { + dep_graph.add_edge(v_id.get(), idx, ()); + } // NoOp, Fill/Arange, Rand/Randn don’t create incoming edges Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {} } } - dbg!(&dep_graph); // Compute topological order let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); @@ -279,10 +282,17 @@ impl BackendDevice for CpuDevice { out } Op::NoOp => unreachable!("NoOp should not be evaluated."), + Op::Permute { v_id } => { + if v_id.is_inplace() { + results[v_id.get()].take().unwrap() + } else { + let buf = results[v_id.get()].as_ref().unwrap(); + PooledBuffer::new((*buf).clone(), pool.clone()) + } + } }; results[idx] = Some(computed); - dbg!(&op.strides); results_strides[idx] = Some(op.strides.clone()); } diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index fd0ef49..fba7e9b 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -22,6 +22,7 @@ pub struct GraphNode { pub op: Op, pub shape: Vec, pub strides: Vec, + pub id: GraphTensorId, } #[derive(Clone)] @@ -45,11 +46,12 @@ impl Graph { } /// Append an operation to the graph - pub(crate) fn add_op(&self, op: Op, strides: &[usize]) { + pub(crate) fn add_op(&self, op: Op, strides: &[usize], id: &GraphTensorId) { self.data.write().unwrap().push(GraphNode { op, shape: S::shape(), strides: strides.to_vec(), + id: id.clone(), }); } @@ -90,6 +92,7 @@ impl Graph { Op::FusedMulAdd { .. } => "FMA".to_string(), // Matrix multiplication Op::MatMul { .. } => "MatMul".to_string(), + Op::Permute { v_id: _ } => "Permute".to_string(), // we already matched NoOp above Op::NoOp => unreachable!(), }; @@ -172,6 +175,15 @@ impl Graph { } } } + Op::Permute { v_id, .. } => { + if let Some(src) = idx_map[v_id.get()] { + let mut label = "v".to_string(); + if v_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); + } + } // NoOp, Fill/Arange, Rand/Randn don’t create incoming edges Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {} } @@ -321,6 +333,7 @@ impl Graph { .as_ref() .map(|o| vec![l_id, r_id, o]) .unwrap_or(vec![l_id, r_id]), + Op::Permute { v_id } => vec![v_id], Op::NoOp => vec![], }; @@ -379,6 +392,9 @@ impl Graph { *usage.entry(o_id.clone()).or_default() += 1; } } + Op::Permute { v_id } => { + *usage.entry(v_id.clone()).or_default() += 1; + } // No input usage for these ops Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {} } @@ -532,7 +548,14 @@ impl Graph { keep[o_id.get()] = true; } } - _ => {} + Op::Permute { v_id, .. } => { + keep[v_id.get()] = true; + } + Op::NoOp + | Op::Fill { .. } + | Op::Arange { .. } + | Op::Rand + | Op::Randn { .. } => (), } } } @@ -598,7 +621,6 @@ impl Graph { /// - Inplace matrix-multiplication when safe /// - Dead code removal pub fn optimize(&mut self) { - return; // Constant folding first self.optimize_const(); // Fuse mul-add into FMA @@ -748,6 +770,10 @@ pub enum Op { mean: T, std: T, }, + // Permutation operator. + Permute { + v_id: GraphTensorId, + }, NoOp, } diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs index be9e3fb..603f40f 100644 --- a/constensor-core/src/tensor/graphtensor.rs +++ b/constensor-core/src/tensor/graphtensor.rs @@ -31,6 +31,7 @@ impl self, rhs: GraphTensor, T, D>, ) -> GraphTensor, T, D> { + let id = self.graph.write().unwrap().next_id(); self.graph.write().unwrap().add_op::>( Op::MatMul { l_id: self.id(), @@ -41,9 +42,10 @@ impl beta: T::ONE, }, &self.strides, + &id, ); GraphTensor { - id: self.graph.write().unwrap().next_id(), + id, graph: self.graph.clone(), strides: self.strides.clone(), _ghost: PhantomData, @@ -60,6 +62,7 @@ impl alpha: T, beta: T, ) -> GraphTensor, T, D> { + let id = self.graph.write().unwrap().next_id(); self.graph.write().unwrap().add_op::>( Op::MatMul { l_id: self.id(), @@ -70,9 +73,10 @@ impl beta, }, &self.strides, + &id, ); GraphTensor { - id: self.graph.write().unwrap().next_id(), + id, graph: self.graph.clone(), strides: self.strides.clone(), _ghost: PhantomData, @@ -86,7 +90,7 @@ impl GraphTensor { pub fn fill(graph: &mut Graph, v: T) -> Self { let id = graph.next_id(); let strides = contiguous_strides(&S::shape()); - graph.add_op::(Op::Fill { v }, &strides); + graph.add_op::(Op::Fill { v }, &strides, &id); Self { id, graph: Arc::new(RwLock::new(graph.clone())), @@ -110,15 +114,17 @@ impl GraphTensor { #[must_use] /// Elementwise unary square root. pub fn sqrt(self) -> GraphTensor { + let id = self.graph.write().unwrap().next_id(); self.graph.write().unwrap().add_op::( Op::UnaryOp { v_id: self.id(), operator: UnaryOpType::Sqrt, }, &self.strides, + &id, ); Self { - id: self.graph.write().unwrap().next_id(), + id, graph: self.graph.clone(), strides: self.strides.clone(), _ghost: PhantomData, @@ -130,7 +136,7 @@ impl GraphTensor { pub fn rand(graph: &mut Graph) -> Self { let id = graph.next_id(); let strides = contiguous_strides(&S::shape()); - graph.add_op::(Op::Rand, &strides); + graph.add_op::(Op::Rand, &strides, &id); GraphTensor { id, graph: Arc::new(RwLock::new(graph.clone())), @@ -144,7 +150,7 @@ impl GraphTensor { pub fn randn(graph: &mut Graph, mean: T, std: T) -> Self { let id = graph.next_id(); let strides = contiguous_strides(&S::shape()); - graph.add_op::(Op::Randn { mean, std }, &strides); + graph.add_op::(Op::Randn { mean, std }, &strides, &id); GraphTensor { id, graph: Arc::new(RwLock::new(graph.clone())), @@ -180,6 +186,7 @@ impl GraphTensor, T, D> { stop, }, &strides, + &id, ); Self { id, @@ -196,8 +203,18 @@ impl GraphTensor, T, // swap strides for first two dimensions let mut new_strides = self.strides.clone(); new_strides.swap(0, 1); + + let id = self.graph.write().unwrap().next_id(); + + self.graph.write().unwrap().add_op::>( + Op::Permute { + v_id: self.id.clone(), + }, + &new_strides, + &id, + ); GraphTensor { - id: self.graph.write().unwrap().next_id(), + id, graph: self.graph.clone(), strides: new_strides, _ghost: PhantomData, @@ -213,8 +230,18 @@ impl // swap strides for last two dimensions let mut new_strides = self.strides.clone(); new_strides.swap(1, 2); + + let id = self.graph.write().unwrap().next_id(); + + self.graph.write().unwrap().add_op::>( + Op::Permute { + v_id: self.id.clone(), + }, + &new_strides, + &id, + ); GraphTensor { - id: self.graph.write().unwrap().next_id(), + id, graph: self.graph.clone(), strides: new_strides, _ghost: PhantomData, @@ -228,6 +255,7 @@ macro_rules! graphtensor_binop { type Output = GraphTensor; /// Add an elementwise operation to the graph. fn $fn_name(self, rhs: Self) -> Self::Output { + let id = self.graph.write().unwrap().next_id(); self.graph.write().unwrap().add_op::( Op::BinaryOp { l_id: self.id(), @@ -235,9 +263,10 @@ macro_rules! graphtensor_binop { operator: BinaryOpType::$trait, }, &self.strides, + &id, ); Self { - id: self.graph.write().unwrap().next_id(), + id, graph: self.graph.clone(), strides: self.strides.clone(), _ghost: PhantomData, @@ -256,15 +285,17 @@ impl, D: Dev> Neg for GraphTensor type Output = GraphTensor; /// Add an elementwise addition operation to the graph. fn neg(self) -> Self::Output { + let id = self.graph.write().unwrap().next_id(); self.graph.write().unwrap().add_op::( Op::UnaryOp { v_id: self.id(), operator: UnaryOpType::Neg, }, &self.strides, + &id, ); Self { - id: self.graph.write().unwrap().next_id(), + id, graph: self.graph.clone(), strides: self.strides.clone(), _ghost: PhantomData, From 7d81789ba0f85d8347b9f4aed3773c8e3c8744e5 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 14:47:27 -0400 Subject: [PATCH 04/11] Strided simd gemm --- constensor-core/src/dtype/gemm.rs | 33 +++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs index 44b7e86..f2c3a95 100644 --- a/constensor-core/src/dtype/gemm.rs +++ b/constensor-core/src/dtype/gemm.rs @@ -253,6 +253,18 @@ macro_rules! instantiate_gemm { let n_blocks = n / BLOCK_SIZE; let rem = n % BLOCK_SIZE; + let lhs_bs = lhs_stride[0]; + let lhs_rs = lhs_stride[1]; + let lhs_cs = lhs_stride[2]; + + let rhs_bs = rhs_stride[0]; + let rhs_rs = rhs_stride[1]; + let rhs_cs = rhs_stride[2]; + + let out_bs = out_stride[0]; + let out_rs = out_stride[1]; + let out_cs = out_stride[2]; + debug_assert_eq!(lhs.len(), b * m * k); debug_assert_eq!(lhs_stride.len(), 3); debug_assert_eq!(rhs.len(), b * k * n); @@ -262,18 +274,18 @@ macro_rules! instantiate_gemm { for batch in 0..b { // Compute base pointers once per batch - let lhs_base = unsafe { lhs.as_ptr().add(batch * m * k) }; - let rhs_base = unsafe { rhs.as_ptr().add(batch * k * n) }; - let out_base = unsafe { out.as_mut_ptr().add(batch * m * n) }; + let lhs_base = unsafe { lhs.as_ptr().add(batch * lhs_bs) }; + let rhs_base = unsafe { rhs.as_ptr().add(batch * rhs_bs) }; + let out_base = unsafe { out.as_mut_ptr().add(batch * out_bs) }; for i in 0..m { // Pointer to the start of the current output row - let out_row_ptr = unsafe { out_base.add(i * n) }; + let out_row_ptr = unsafe { out_base.add(i * out_rs) }; // Process full SIMD blocks for block in 0..n_blocks { let off = block * BLOCK_SIZE; - let out_ptr = unsafe { out_row_ptr.add(off) }; + let out_ptr = unsafe { out_row_ptr.add(off * out_cs) }; let out_chunk = unsafe { std::slice::from_raw_parts_mut(out_ptr, BLOCK_SIZE) }; @@ -291,9 +303,9 @@ macro_rules! instantiate_gemm { } for p in 0..k { - let a_val = unsafe { *lhs_base.add(i * k + p) }; + let a_val = unsafe { *lhs_base.add(i * lhs_rs + p * lhs_cs) }; let a_arr = [a_val; BLOCK_SIZE]; - let b_ptr = unsafe { rhs_base.add(p * n + off) }; + let b_ptr = unsafe { rhs_base.add(p * rhs_rs + off * rhs_cs) }; let b_chunk = unsafe { std::slice::from_raw_parts(b_ptr, BLOCK_SIZE) }; ::fma_op_inplace_c( @@ -305,7 +317,7 @@ macro_rules! instantiate_gemm { // Handle remainder elements if rem > 0 { let off = n_blocks * BLOCK_SIZE; - let out_ptr = unsafe { out_row_ptr.add(off) }; + let out_ptr = unsafe { out_row_ptr.add(off * out_cs) }; let out_chunk = unsafe { std::slice::from_raw_parts_mut(out_ptr, rem) }; if beta != $init { @@ -319,9 +331,10 @@ macro_rules! instantiate_gemm { } for p in 0..k { - let a_val = unsafe { *lhs_base.add(i * k + p) }; + let a_val = unsafe { *lhs_base.add(i * lhs_rs + p * lhs_cs) }; for j in 0..rem { - let b_val = unsafe { *rhs_base.add(p * n + off + j) }; + let b_val = + unsafe { *rhs_base.add(p * rhs_rs + (off + j) * rhs_cs) }; out_chunk[j] += a_val * b_val; } } From 272f8c9480a0a9f49a46755c5770fe0a24ac1001 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 14:48:23 -0400 Subject: [PATCH 05/11] Strided naive gemm --- constensor-core/src/dtype/gemm.rs | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs index f2c3a95..d0fbb90 100644 --- a/constensor-core/src/dtype/gemm.rs +++ b/constensor-core/src/dtype/gemm.rs @@ -121,26 +121,43 @@ macro_rules! instantiate_gemm { impl GemmDispatch for $rt { fn launch_gemm( lhs: &[Self], + lhs_stride: &[usize], rhs: &[Self], + rhs_stride: &[usize], b: usize, m: usize, n: usize, k: usize, out: &mut Vec, + out_stride: &[usize], alpha: Self, beta: Self, ) where Self: Sized, { - for b in 0..b { + let lhs_bs = lhs_stride[0]; + let lhs_rs = lhs_stride[1]; + let lhs_cs = lhs_stride[2]; + + let rhs_bs = rhs_stride[0]; + let rhs_rs = rhs_stride[1]; + let rhs_cs = rhs_stride[2]; + + let out_bs = out_stride[0]; + let out_rs = out_stride[1]; + let out_cs = out_stride[2]; + + for batch_idx in 0..b { for i in 0..m { for j in 0..n { let mut sum = $init; for p in 0..k { - sum += - beta * lhs[b * m * k + i * k + p] * rhs[b * k * n + p * n + j]; + let lhs_val = lhs[batch_idx * lhs_bs + i * lhs_rs + p * lhs_cs]; + let rhs_val = rhs[batch_idx * rhs_bs + p * rhs_rs + j * rhs_cs]; + sum += beta * lhs_val * rhs_val; } - out[b * m * n + i * n + j] = alpha * out[b * m * n + i * n + j] + sum; + let out_idx = batch_idx * out_bs + i * out_rs + j * out_cs; + out[out_idx] = alpha * out[out_idx] + sum; } } } From b70c63b1bcdaed87e1950a0b161a34a2183c26cb Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 14:49:43 -0400 Subject: [PATCH 06/11] Fix unused import --- constensor-core/src/dtype/rand.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/constensor-core/src/dtype/rand.rs b/constensor-core/src/dtype/rand.rs index 605b70c..4603786 100644 --- a/constensor-core/src/dtype/rand.rs +++ b/constensor-core/src/dtype/rand.rs @@ -4,11 +4,6 @@ use { crate::{cuda_backend::error::WrapErr, Result}, cudarc::{curand::CudaRng, driver::CudaSlice}, }; -// Optional half-precision types -#[cfg(feature = "bfloat")] -use half::bf16; -#[cfg(feature = "half")] -use half::f16; /// Dispatch random fills based on the data type (CUDA backend). pub trait RandDispatch { From 07a88e3c3c9b95ad57fff84196722dd0247a7795 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 18:54:56 +0000 Subject: [PATCH 07/11] Cuda compiles, not implemented yet --- constensor-core/src/cuda_backend/mod.rs | 16 ++++++++++++++-- constensor-core/src/dtype/rand.rs | 4 ++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index cae4205..8b7100b 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -304,6 +304,10 @@ fn handle_node( format!("( static_cast(fma(static_cast({a_name}), static_cast({b_name}), static_cast({c_name}))))") } Op::NoOp => unreachable!("no-op ops should never be reached."), + Op::Permute { v_id } => { + let name = handle_node(current_name, header, &graph[v_id.get()], graph); + format!("({})", name) + } Op::MatMul { .. } | Op::Rand | Op::Randn { .. } => { unreachable!("op should have its own split!") } @@ -468,7 +472,7 @@ impl BackendDevice for CudaDevice { ) -> Result> { // Build a dependency graph of tensor indices let mut dep_graph = DiGraphMap::::new(); - for idx in 0..graph.len() { + for idx in graph.iter().map(|node| node.id.get()) { dep_graph.add_node(idx); } @@ -486,9 +490,17 @@ impl BackendDevice for CudaDevice { dep_graph.add_edge(b_id.get(), idx, ()); dep_graph.add_edge(c_id.get(), idx, ()); } - Op::MatMul { l_id, r_id, .. } => { + Op::MatMul { + l_id, r_id, o_id, .. + } => { dep_graph.add_edge(l_id.get(), idx, ()); dep_graph.add_edge(r_id.get(), idx, ()); + if let Some(o_id) = o_id { + dep_graph.add_edge(o_id.get(), idx, ()); + } + } + Op::Permute { v_id } => { + dep_graph.add_edge(v_id.get(), idx, ()); } // These don’t create incoming edges Op::NoOp | Op::Fill { .. } | Op::Rand | Op::Randn { .. } | Op::Arange { .. } => {} diff --git a/constensor-core/src/dtype/rand.rs b/constensor-core/src/dtype/rand.rs index 4603786..8efd38d 100644 --- a/constensor-core/src/dtype/rand.rs +++ b/constensor-core/src/dtype/rand.rs @@ -138,7 +138,7 @@ impl RandDispatch for i64 { } } #[cfg(all(feature = "cuda", feature = "half"))] -impl RandDispatch for f16 { +impl RandDispatch for half::f16 { fn cuda_fill_with_uniform(_rng: &CudaRng, _slice: &mut CudaSlice) -> Result<()> { crate::bail!( "Uniform random fill is not supported for dtype {}", @@ -158,7 +158,7 @@ impl RandDispatch for f16 { } } #[cfg(all(feature = "cuda", feature = "bfloat"))] -impl RandDispatch for bf16 { +impl RandDispatch for half::bf16 { fn cuda_fill_with_uniform(_rng: &CudaRng, _slice: &mut CudaSlice) -> Result<()> { crate::bail!( "Uniform random fill is not supported for dtype {}", From 6672a9067efeb4305152ba8635c3caa310c361d0 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 19:04:26 +0000 Subject: [PATCH 08/11] Support strided cublas gemm --- constensor-core/src/cuda_backend/mod.rs | 23 +++++++++++-- constensor-core/src/cuda_backend/util.rs | 44 +++++++++++++++++------- constensor-core/src/dtype/gemm.rs | 18 +++++++++- constensor-core/src/error.rs | 7 ++-- 4 files changed, 73 insertions(+), 19 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 8b7100b..16448d6 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -25,6 +25,7 @@ use crate::{ cpu_storage::CpuStorage, device::Dev, storage::{BackendDevice, BackendStorage, Storage}, + tensor::contiguous_strides, CompiledGraph, DType, GraphNode, Op, Result, Shape, }; @@ -211,6 +212,9 @@ pub enum CudaCompiledKernel { r_id: usize, /// Optional output tensor ID for axpby semantics o_id: Option, + l_stride: Vec, + r_stride: Vec, + o_stride: Option>, b: usize, m: usize, n: usize, @@ -541,8 +545,12 @@ impl BackendDevice for CudaDevice { } => { let l_shape = &graph[l_id.get()].shape; let r_shape = &graph[r_id.get()].shape; + let l_stride = &graph[l_id.get()].strides; + let r_stride = &graph[r_id.get()].strides; assert_eq!(l_shape.len(), 3); assert_eq!(r_shape.len(), 3); + assert_eq!(l_stride.len(), 3); + assert_eq!(r_stride.len(), 3); let (b, m, _k) = (l_shape[0], l_shape[1], l_shape[2]); let n = r_shape[2]; @@ -554,6 +562,9 @@ impl BackendDevice for CudaDevice { l_id: l_id.get(), r_id: r_id.get(), o_id: o_id.as_ref().map(|id| id.get()), + l_stride: l_stride.clone(), + r_stride: r_stride.clone(), + o_stride: o_id.as_ref().map(|id| graph[id.get()].strides.clone()), b, m, n, @@ -666,6 +677,9 @@ impl BackendDevice for CudaDevice { l_id, r_id, o_id, + l_stride, + r_stride, + o_stride, b, m, n, @@ -684,7 +698,7 @@ impl BackendDevice for CudaDevice { lhs.event.synchronize().w()?; rhs.event.synchronize().w()?; - let elems = m * n; + let elems = b * m * n; // prepare output buffer, copy initial if provided let mut out = unsafe { stream.alloc::(elems) }.w()?; if let Some(o_idx) = o_id { @@ -694,9 +708,14 @@ impl BackendDevice for CudaDevice { self.stream().memcpy_dtod(&init.slice, &mut out).w()?; } + let o_stride = o_stride + .clone() + .unwrap_or(contiguous_strides(&[*b, *m, *n])); + // Launch GEMM on the pooled stream T::launch_gemm_cuda( - cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, + cublas, &lhs.slice, &rhs.slice, l_stride, r_stride, *b, *m, *n, *k, + &mut out, &o_stride, *beta, *alpha, )?; // Record completion event for the MatMul result diff --git a/constensor-core/src/cuda_backend/util.rs b/constensor-core/src/cuda_backend/util.rs index 9ac05ff..73f26a7 100644 --- a/constensor-core/src/cuda_backend/util.rs +++ b/constensor-core/src/cuda_backend/util.rs @@ -6,12 +6,13 @@ pub(crate) fn gemm_config( alpha: T, beta: T, (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + rhs_stride: &[usize], + out_stride: &[usize], ) -> Result> { // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm use cudarc::cublas::sys::cublasOperation_t; - let lhs_stride = [m * k, k, 1]; - let rhs_stride = [k * n, n, 1]; let lhs_dims = [b, m, k]; let rhs_dims = [b, k, n]; @@ -28,8 +29,9 @@ pub(crate) fn gemm_config( (k as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(Error::MatMulNonContiguous { - lhs_stride, - rhs_stride, + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + out_stride: out_stride.to_vec(), mnk: (m, n, k), })? }; @@ -42,8 +44,9 @@ pub(crate) fn gemm_config( (m as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(Error::MatMulNonContiguous { - lhs_stride, - rhs_stride, + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + out_stride: out_stride.to_vec(), mnk: (m, n, k), })? }; @@ -62,27 +65,42 @@ pub(crate) fn gemm_config( transb, }; - let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] { + let stride_a: usize = match lhs_stride[..lhs_stride.len() - 2] { [s1, stride] if s1 == stride * lhs_dims[1] => stride, [_, stride] if lhs_dims[0] == 1 => stride, [stride, _] if lhs_dims[1] == 1 => stride, [stride] => stride, [] => m * k, _ => Err(Error::MatMulNonContiguous { - lhs_stride, - rhs_stride, + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + out_stride: out_stride.to_vec(), mnk: (m, n, k), })?, }; - let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] { + let stride_b: usize = match rhs_stride[..rhs_stride.len() - 2] { [s1, stride] if s1 == stride * rhs_dims[1] => stride, [_, stride] if rhs_dims[0] == 1 => stride, [stride, _] if rhs_dims[1] == 1 => stride, [stride] => stride, [] => n * k, _ => Err(Error::MatMulNonContiguous { - lhs_stride, - rhs_stride, + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + out_stride: out_stride.to_vec(), + mnk: (m, n, k), + })?, + }; + let stride_c: usize = match out_stride[..out_stride.len() - 2] { + [s1, stride] if s1 == stride * rhs_dims[1] => stride, + [_, stride] if rhs_dims[0] == 1 => stride, + [stride, _] if rhs_dims[1] == 1 => stride, + [stride] => stride, + [] => m * n, + _ => Err(Error::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + out_stride: out_stride.to_vec(), mnk: (m, n, k), })?, }; @@ -92,6 +110,6 @@ pub(crate) fn gemm_config( gemm, stride_a: stride_a as i64, stride_b: stride_b as i64, - stride_c: (m * n) as i64, + stride_c: stride_c as i64, }) } diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs index d0fbb90..9ba49cb 100644 --- a/constensor-core/src/dtype/gemm.rs +++ b/constensor-core/src/dtype/gemm.rs @@ -34,11 +34,14 @@ pub trait GemmDispatch { cublas: &cudarc::cublas::CudaBlas, lhs: &cudarc::driver::CudaSlice, rhs: &cudarc::driver::CudaSlice, + lhs_stride: &[usize], + rhs_stride: &[usize], b: usize, m: usize, n: usize, k: usize, out: &mut cudarc::driver::CudaSlice, + out_stride: &[usize], alpha: Self, beta: Self, ) -> crate::Result<()> @@ -66,11 +69,14 @@ macro_rules! instantiate_gemm_cuda { _cublas: &cudarc::cublas::CudaBlas, _lhs: &cudarc::driver::CudaSlice, _rhs: &cudarc::driver::CudaSlice, + _lhs_stride: &[usize], + _rhs_stride: &[usize], _b: usize, _m: usize, _n: usize, _k: usize, _out: &mut cudarc::driver::CudaSlice, + _out_stride: &[usize], _alpha: Self, _beta: Self, ) -> crate::Result<()> @@ -87,18 +93,28 @@ macro_rules! instantiate_gemm_cuda { cublas: &cudarc::cublas::CudaBlas, lhs: &cudarc::driver::CudaSlice<$rt>, rhs: &cudarc::driver::CudaSlice<$rt>, + lhs_stride: &[usize], + rhs_stride: &[usize], b: usize, m: usize, n: usize, k: usize, out: &mut cudarc::driver::CudaSlice<$rt>, + out_stride: &[usize], alpha: $rt, beta: $rt, ) -> crate::Result<()> { use crate::cuda_backend::error::WrapErr; use cudarc::cublas::Gemm; - let gemm_cfg = crate::cuda_backend::util::gemm_config(alpha, beta, (b, m, n, k))?; + let gemm_cfg = crate::cuda_backend::util::gemm_config( + alpha, + beta, + (b, m, n, k), + lhs_stride, + rhs_stride, + out_stride, + )?; unsafe { cublas diff --git a/constensor-core/src/error.rs b/constensor-core/src/error.rs index 8ed3fab..81a4588 100644 --- a/constensor-core/src/error.rs +++ b/constensor-core/src/error.rs @@ -28,10 +28,11 @@ pub enum Error { context: String, }, - #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} ostride: {out_stride:?} mnk: {mnk:?}")] MatMulNonContiguous { - lhs_stride: [usize; 3], - rhs_stride: [usize; 3], + lhs_stride: Vec, + rhs_stride: Vec, + out_stride: Vec, mnk: (usize, usize, usize), }, } From 0845f4c4f60ab3f693791d19cab18ad9f188dbfa Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 19:08:37 +0000 Subject: [PATCH 09/11] Support strided cublas gemm --- constensor-core/examples/matmul/main.rs | 13 +++++++------ constensor-core/src/device.rs | 5 +++++ constensor-core/src/lib.rs | 2 +- constensor-core/src/tensor/graphtensor.rs | 4 ++-- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/constensor-core/examples/matmul/main.rs b/constensor-core/examples/matmul/main.rs index 25d218f..f194f9b 100644 --- a/constensor-core/examples/matmul/main.rs +++ b/constensor-core/examples/matmul/main.rs @@ -1,4 +1,4 @@ -use constensor_core::{CompiledGraph, Cpu, DType, Graph, GraphTensor, R3}; +use constensor_core::{BestDevice, CompiledGraph, DType, Graph, GraphTensor, R3}; use std::time::Instant; fn bench( @@ -11,14 +11,15 @@ fn bench, T, Cpu>::fill(&mut graph, T::from_f64(1.)); - let b = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(2.)).t(); - // let b = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(2.)); - let o = GraphTensor::, T, Cpu>::fill(&mut graph, T::from_f64(3.)); + let a = GraphTensor::, T, BestDevice<0>>::fill(&mut graph, T::from_f64(1.)); + // Strided matmuls works on all devices. + let b = GraphTensor::, T, BestDevice<0>>::fill(&mut graph, T::from_f64(2.)).t(); + // let b = GraphTensor::, T, BestDevice<0>>::fill(&mut graph, T::from_f64(2.)); + let o = GraphTensor::, T, BestDevice<0>>::fill(&mut graph, T::from_f64(3.)); let _c = a.matmul_axpby(b, o, alpha, beta); graph.optimize(); - let compiled: CompiledGraph, T, Cpu> = graph.compile().unwrap(); + let compiled: CompiledGraph, T, BestDevice<0>> = graph.compile().unwrap(); for _ in 0..iterations { let start = Instant::now(); diff --git a/constensor-core/src/device.rs b/constensor-core/src/device.rs index 7825235..fe03c45 100644 --- a/constensor-core/src/device.rs +++ b/constensor-core/src/device.rs @@ -57,6 +57,11 @@ cuda_device!(8); #[cfg(feature = "cuda")] cuda_device!(9); +#[cfg(feature = "cuda")] +pub type BestDevice = Cuda; +#[cfg(not(feature = "cuda"))] +pub type BestDevice = Cpu; + /// A concrete device. #[derive(Clone)] pub enum Device { diff --git a/constensor-core/src/lib.rs b/constensor-core/src/lib.rs index a614f3e..d328005 100644 --- a/constensor-core/src/lib.rs +++ b/constensor-core/src/lib.rs @@ -11,9 +11,9 @@ mod shape; mod storage; mod tensor; -pub use device::Cpu; #[cfg(feature = "cuda")] pub use device::Cuda; +pub use device::{BestDevice, Cpu}; pub use dtype::DType; pub use error::{Context, Error, Result}; pub use graph::{CompiledGraph, Graph, GraphNode, Op}; diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs index 603f40f..006ac18 100644 --- a/constensor-core/src/tensor/graphtensor.rs +++ b/constensor-core/src/tensor/graphtensor.rs @@ -206,7 +206,7 @@ impl GraphTensor, T, let id = self.graph.write().unwrap().next_id(); - self.graph.write().unwrap().add_op::>( + self.graph.write().unwrap().add_op::>( Op::Permute { v_id: self.id.clone(), }, @@ -233,7 +233,7 @@ impl let id = self.graph.write().unwrap().next_id(); - self.graph.write().unwrap().add_op::>( + self.graph.write().unwrap().add_op::>( Op::Permute { v_id: self.id.clone(), }, From 335fc3098fef27e3d842fbaf1cf202cccfc4acfc Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 19:17:44 +0000 Subject: [PATCH 10/11] A small fix --- constensor-core/src/cuda_backend/util.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/constensor-core/src/cuda_backend/util.rs b/constensor-core/src/cuda_backend/util.rs index 73f26a7..512da2e 100644 --- a/constensor-core/src/cuda_backend/util.rs +++ b/constensor-core/src/cuda_backend/util.rs @@ -15,6 +15,7 @@ pub(crate) fn gemm_config( let lhs_dims = [b, m, k]; let rhs_dims = [b, k, n]; + let out_dims = [b, m, n]; let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; @@ -92,9 +93,9 @@ pub(crate) fn gemm_config( })?, }; let stride_c: usize = match out_stride[..out_stride.len() - 2] { - [s1, stride] if s1 == stride * rhs_dims[1] => stride, - [_, stride] if rhs_dims[0] == 1 => stride, - [stride, _] if rhs_dims[1] == 1 => stride, + [s1, stride] if s1 == stride * out_dims[1] => stride, + [_, stride] if out_dims[0] == 1 => stride, + [stride, _] if out_dims[1] == 1 => stride, [stride] => stride, [] => m * n, _ => Err(Error::MatMulNonContiguous { From fb9b2d49c516f2460fc3d96f0cf64dfb15ac267e Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 19:44:28 +0000 Subject: [PATCH 11/11] Works --- constensor-core/src/cuda_backend/mod.rs | 43 ++++++++++++++----------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 16448d6..99c2d57 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -13,7 +13,7 @@ use std::sync::{ }; use std::{ borrow::Cow, - collections::{HashMap, HashSet, VecDeque}, + collections::{HashMap, VecDeque}, fs, hash::{DefaultHasher, Hash, Hasher}, marker::PhantomData, @@ -518,20 +518,6 @@ impl BackendDevice for CudaDevice { let mut kernels = Vec::>::new(); let mut matmuls = Vec::>::new(); let mut splits: Vec<(Vec, Vec)> = Vec::new(); - // Collect all matmul input node indices - let mut matmul_inputs = HashSet::new(); - for &idx in &order { - if let Op::MatMul { - l_id, r_id, o_id, .. - } = &graph[idx].op - { - matmul_inputs.insert(l_id.get()); - matmul_inputs.insert(r_id.get()); - if let Some(o_id) = o_id { - matmul_inputs.insert(o_id.get()); - } - } - } for &idx in &order { match &graph[idx].op { @@ -606,11 +592,32 @@ impl BackendDevice for CudaDevice { } _ => { let shape_key = graph[idx].shape.clone(); + // Group only when same shape and this op depends on the last split node let should_group = if let Some((last_group, _)) = splits.last_mut() { let last_idx = *last_group.last().unwrap(); - let last_shape_key = graph[last_idx].shape.clone(); - // Force all matmul inputs to have their own - last_shape_key == shape_key && !matmul_inputs.contains(&idx) + if graph[last_idx].shape == shape_key { + match &graph[idx].op { + Op::BinaryOp { l_id, r_id, .. } => { + l_id.get() == last_idx || r_id.get() == last_idx + } + Op::UnaryOp { v_id, .. } => v_id.get() == last_idx, + Op::FusedMulAdd { a_id, b_id, c_id } => { + a_id.get() == last_idx + || b_id.get() == last_idx + || c_id.get() == last_idx + } + Op::Permute { v_id } => v_id.get() == last_idx, + // Init ops always start new group + Op::NoOp + | Op::Fill { .. } + | Op::Arange { .. } + | Op::Rand + | Op::Randn { .. } + | Op::MatMul { .. } => false, + } + } else { + false + } } else { false };