From 58833ea0d12f2329f05d1bd30966a072e8f1091c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 14:43:46 +0000 Subject: [PATCH 01/18] Fix dtype --- constensor-core/src/cuda_backend/mod.rs | 7 ++----- constensor-core/src/device.rs | 2 +- constensor-core/src/dtype/mod.rs | 1 - 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 47561a2..8dd00de 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -18,7 +18,7 @@ use crate::{ cpu_storage::CpuStorage, graph::GraphTensorId, storage::{BackendDevice, BackendStorage}, - DType, Op, Result, SignedDType, + DType, GraphNode, Op, Result, }; #[derive(Clone)] @@ -305,10 +305,7 @@ impl CudaDevice { impl BackendDevice for CudaDevice { type Storage = CudaStorage; - fn compile_and_run_graph( - &self, - nodes: &[crate::Op], - ) -> Result> { + fn compile_and_run_graph(&self, nodes: &[GraphNode]) -> Result> { let mut header = "".to_string(); let body = handle_node(&mut 0, &mut header, nodes.last().unwrap(), nodes); self.run_graph::(header, body) diff --git a/constensor-core/src/device.rs b/constensor-core/src/device.rs index 4c181fd..58bad50 100644 --- a/constensor-core/src/device.rs +++ b/constensor-core/src/device.rs @@ -69,7 +69,7 @@ impl Device { pub fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { match self { #[cfg(feature = "cuda")] - Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.compile_and_run_graph::(graph)?)), + Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.compile_and_run_graph::(graph)?)), Self::Cpu => Ok(Storage::Cpu(CpuDevice.compile_and_run_graph::(graph)?)), } } diff --git a/constensor-core/src/dtype/mod.rs b/constensor-core/src/dtype/mod.rs index 7a39cc6..56bf6bd 100644 --- a/constensor-core/src/dtype/mod.rs +++ b/constensor-core/src/dtype/mod.rs @@ -143,7 +143,6 @@ maybe_neg!(i64); maybe_neg!(f32); maybe_neg!(f64); -#[cfg(not(feature = "cuda"))] /// Marker trait for tensor datatypes. pub trait DType: Debug + Clone + DTypeOps + Send + Sync + MaybeNeg + DeviceReprLike + 'static From d842b811edf126231fa81ed809e66005a6ae8168 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 14:47:06 +0000 Subject: [PATCH 02/18] Hmm --- constensor-core/src/cuda_backend/mod.rs | 66 ++++++++----------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 8dd00de..b16708d 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -83,36 +83,21 @@ impl Name { fn handle_node( current_name: &mut usize, header: &mut String, - op: &Op, - graph: &[Op], + op: &GraphNode, + graph: &[GraphNode], ) -> String { - match op { + match &op.op { Op::BinaryOp { l_id, r_id, operator, } => { - let l_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(l_id)], - graph, - ); - let r_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(r_id)], - graph, - ); + let l_name = handle_node(current_name, header, &graph[l_id.get()], graph); + let r_name = handle_node(current_name, header, &graph[r_id.get()], graph); format!("({l_name} {} {r_name})", operator.as_c_op()) } Op::UnaryOp { v_id, operator } => { - let v_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(v_id)], - graph, - ); + let v_name = handle_node(current_name, header, &graph[v_id.get()], graph); operator.fill_in_c_op(v_name) } Op::Fill { v } => { @@ -132,24 +117,9 @@ fn handle_node( format!("({})", name.to_name()) } Op::FusedMulAdd { a_id, b_id, c_id } => { - let a_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(a_id)], - graph, - ); - let b_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(b_id)], - graph, - ); - let c_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(c_id)], - graph, - ); + let a_name = handle_node(current_name, header, &graph[a_id.get()], graph); + let b_name = handle_node(current_name, header, &graph[b_id.get()], graph); + let c_name = handle_node(current_name, header, &graph[c_id.get()], graph); #[cfg(feature = "slow_integral_fma_cuda")] if T::INTEGRAL { use crate::graph::BinaryOpType; @@ -163,6 +133,16 @@ fn handle_node( format!("( static_cast(fma(static_cast({a_name}), static_cast({b_name}), static_cast({c_name}))))") } Op::NoOp => unreachable!("no-op ops should never be reached."), + Op::MatMul { + l_id, + r_id, + o_id, + k, + alpha, + beta, + } => { + todo!() + } } } @@ -215,11 +195,7 @@ fn compile_ptx(template_kernel: String) -> Result { } impl CudaDevice { - fn run_graph( - &self, - header: String, - body: String, - ) -> Result> { + fn run_graph(&self, header: String, body: String) -> Result> { // Module name is based on hash of body and header let mut hasher = DefaultHasher::new(); body.hash(&mut hasher); @@ -308,6 +284,6 @@ impl BackendDevice for CudaDevice { fn compile_and_run_graph(&self, nodes: &[GraphNode]) -> Result> { let mut header = "".to_string(); let body = handle_node(&mut 0, &mut header, nodes.last().unwrap(), nodes); - self.run_graph::(header, body) + self.run_graph::(header, body) } } From df9ec56be233ba528f4763b5bfaff8bbc832ca8b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 15:01:57 +0000 Subject: [PATCH 03/18] Graph splits --- constensor-core/examples/hello_world/main.rs | 18 ++--- constensor-core/src/cuda_backend/mod.rs | 79 ++++++++++++++++++-- 2 files changed, 82 insertions(+), 15 deletions(-) diff --git a/constensor-core/examples/hello_world/main.rs b/constensor-core/examples/hello_world/main.rs index f167c8a..3a960eb 100644 --- a/constensor-core/examples/hello_world/main.rs +++ b/constensor-core/examples/hello_world/main.rs @@ -1,21 +1,21 @@ -use constensor_core::{Cpu, Graph, GraphTensor, Tensor, R1, R2}; +use constensor_core::{Cpu, Cuda, Graph, GraphTensor, Tensor, R1, R2}; fn main() { let mut graph: Graph = Graph::empty(); - let arange = GraphTensor::, f32, Cpu>::arange(&mut graph, 0., 1.); - dbg!(&arange.to_tensor().unwrap().data()); - let a = GraphTensor::, f32, Cpu>::fill(&mut graph, 1.0); - let b = GraphTensor::, f32, Cpu>::fill(&mut graph, 2.0); - let c = GraphTensor::, f32, Cpu>::fill(&mut graph, 3.0); - let d = GraphTensor::, f32, Cpu>::fill(&mut graph, 4.0); + // let arange = GraphTensor::, f32, Cuda<0>>::arange(&mut graph, 0., 1.); + // dbg!(&arange.to_tensor().unwrap().data()); + let a = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 1.0); + let b = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 2.0); + let c = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 3.0); + let d = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 4.0); let res = a * b + c; let res = res + d; graph.optimize(); - graph.visualize("graph.png").unwrap(); + // graph.visualize("graph.png").unwrap(); - let tensor: Tensor, f32, Cpu> = res.to_tensor().unwrap(); + let tensor: Tensor, f32, Cuda<0>> = res.to_tensor().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![9.0; 4]; 3],); } diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index b16708d..6083612 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -13,6 +13,7 @@ use cudarc::{ nvrtc::{CompileOptions, Ptx}, }; use error::WrapErr; +use petgraph::{algo::toposort, prelude::DiGraphMap}; use crate::{ cpu_storage::CpuStorage, @@ -107,7 +108,7 @@ fn handle_node( format!("({})", name.to_name()) } Op::Arange { start, step, stop } => { - compile_error!("arange is not implemented for CUDA yet."); + // compile_error!("arange is not implemented for CUDA yet."); *current_name += 1; let name = Name(*current_name); *header += &format!( @@ -257,7 +258,7 @@ impl CudaDevice { fs::write(path, ptx_str)?; } - let n_elems = S::element_count(); + let n_elems = 100; //S::element_count(); let stream = self.stream(); let data = unsafe { stream.alloc::(n_elems) }.w()?; @@ -281,9 +282,75 @@ impl CudaDevice { impl BackendDevice for CudaDevice { type Storage = CudaStorage; - fn compile_and_run_graph(&self, nodes: &[GraphNode]) -> Result> { - let mut header = "".to_string(); - let body = handle_node(&mut 0, &mut header, nodes.last().unwrap(), nodes); - self.run_graph::(header, body) + fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { + // Build a dependency graph of tensor indices + let mut dep_graph = DiGraphMap::::new(); + for idx in 0..graph.len() { + dep_graph.add_node(idx); + } + + for (idx, node) in graph.iter().enumerate() { + match &node.op { + Op::BinaryOp { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); + } + Op::UnaryOp { v_id, .. } => { + dep_graph.add_edge(v_id.get(), idx, ()); + } + Op::FusedMulAdd { a_id, b_id, c_id } => { + dep_graph.add_edge(a_id.get(), idx, ()); + dep_graph.add_edge(b_id.get(), idx, ()); + dep_graph.add_edge(c_id.get(), idx, ()); + } + Op::MatMul { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); + } + // NoOp and Fill/Arange don’t create incoming edges + Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} + } + } + + // Compute topological order + let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); + + // Split into groups of nodes whose input shapes and dtype match + let mut splits: Vec> = Vec::new(); + for &idx in &order { + // Determine a key based on this node's input shapes + let shape_key: Vec = graph[idx].shape.clone(); + let should_group = if let Some(last_group) = splits.last_mut() { + let last_idx = *last_group.last().unwrap(); + let last_shape_key = graph[last_idx].shape.clone(); + last_shape_key == shape_key + } else { + false + }; + if should_group { + splits.last_mut().unwrap().push(idx); + } else { + splits.push(vec![idx]); + } + } + + // For each group of nodes with matching input shapes/dtype, generate and run kernels + let mut last_storage = None; + for sub_order in splits { + // build header/body for this subgraph slice + let mut header = String::new(); + let body = handle_node( + &mut 0, + &mut header, + &graph[*sub_order.last().unwrap()], + graph, + ); + // launch a kernel for this subgroup + let storage = self.run_graph::(header.clone(), body.clone())?; + last_storage = Some(storage); + // handle or collect `storage` as needed; here we return the last one + } + // Return the last storage (corresponds to the output node) + Ok(last_storage.expect("No nodes to execute")) } } From 79dcc650d127ab86307fe0b5b2338752c12063f3 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 15:21:39 +0000 Subject: [PATCH 04/18] Graph splits with multiple functions --- constensor-core/examples/hello_world/main.rs | 4 +- constensor-core/src/cuda_backend/mod.rs | 65 +++++++++----------- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/constensor-core/examples/hello_world/main.rs b/constensor-core/examples/hello_world/main.rs index 3a960eb..2b8c07c 100644 --- a/constensor-core/examples/hello_world/main.rs +++ b/constensor-core/examples/hello_world/main.rs @@ -2,8 +2,8 @@ use constensor_core::{Cpu, Cuda, Graph, GraphTensor, Tensor, R1, R2}; fn main() { let mut graph: Graph = Graph::empty(); - // let arange = GraphTensor::, f32, Cuda<0>>::arange(&mut graph, 0., 1.); - // dbg!(&arange.to_tensor().unwrap().data()); + let arange = GraphTensor::, f32, Cuda<0>>::arange(&mut graph, 0., 1.); + dbg!(&arange.to_tensor().unwrap().data()); let a = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 1.0); let b = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 2.0); let c = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 3.0); diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 6083612..b2f5ccc 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -1,11 +1,11 @@ use std::{ borrow::Cow, - cell::OnceCell, + collections::HashMap, fs, hash::{DefaultHasher, Hash, Hasher}, ops::Deref, path::{Path, PathBuf}, - sync::Arc, + sync::{Arc, RwLock}, }; mod error; use cudarc::{ @@ -17,7 +17,6 @@ use petgraph::{algo::toposort, prelude::DiGraphMap}; use crate::{ cpu_storage::CpuStorage, - graph::GraphTensorId, storage::{BackendDevice, BackendStorage}, DType, GraphNode, Op, Result, }; @@ -26,7 +25,7 @@ use crate::{ pub struct CudaDevice { context: Arc, stream: Arc, - module: OnceCell>, + modules: Arc>>>, } impl CudaDevice { @@ -36,7 +35,7 @@ impl CudaDevice { Ok(Self { context, stream, - module: OnceCell::new(), + modules: Arc::new(RwLock::new(vec![])), }) } @@ -44,11 +43,11 @@ impl CudaDevice { self.stream.clone() } - pub(crate) fn get_or_load_func(&self, function_name: &str, ptx: Ptx) -> Result { - let module = self - .module - .get_or_init(|| self.context.load_module(ptx).w().unwrap()); - module.load_function(function_name).w() + pub(crate) fn load_func(&self, function_name: &str, ptx: Ptx) -> Result { + let module = self.context.load_module(ptx).w()?; + let func = module.load_function(function_name).w()?; + self.modules.write().unwrap().push(module); + Ok(func) } } @@ -108,7 +107,7 @@ fn handle_node( format!("({})", name.to_name()) } Op::Arange { start, step, stop } => { - // compile_error!("arange is not implemented for CUDA yet."); + // todo!(); *current_name += 1; let name = Name(*current_name); *header += &format!( @@ -134,16 +133,7 @@ fn handle_node( format!("( static_cast(fma(static_cast({a_name}), static_cast({b_name}), static_cast({c_name}))))") } Op::NoOp => unreachable!("no-op ops should never be reached."), - Op::MatMul { - l_id, - r_id, - o_id, - k, - alpha, - beta, - } => { - todo!() - } + Op::MatMul { .. } => unreachable!("matmul op should have its own split!"), } } @@ -196,7 +186,12 @@ fn compile_ptx(template_kernel: String) -> Result { } impl CudaDevice { - fn run_graph(&self, header: String, body: String) -> Result> { + fn run_graph( + &self, + header: String, + body: String, + shape: Vec, + ) -> Result> { // Module name is based on hash of body and header let mut hasher = DefaultHasher::new(); body.hash(&mut hasher); @@ -258,12 +253,12 @@ impl CudaDevice { fs::write(path, ptx_str)?; } - let n_elems = 100; //S::element_count(); + let n_elems = shape.iter().product(); let stream = self.stream(); let data = unsafe { stream.alloc::(n_elems) }.w()?; - let func = self.get_or_load_func(&function_name, ptx)?; + let func = self.load_func(&function_name, ptx)?; let cfg = LaunchConfig::for_num_elems(n_elems as u32); @@ -316,11 +311,11 @@ impl BackendDevice for CudaDevice { let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); // Split into groups of nodes whose input shapes and dtype match - let mut splits: Vec> = Vec::new(); + let mut splits: Vec<(Vec, Vec)> = Vec::new(); for &idx in &order { // Determine a key based on this node's input shapes let shape_key: Vec = graph[idx].shape.clone(); - let should_group = if let Some(last_group) = splits.last_mut() { + let should_group = if let Some((last_group, _)) = splits.last_mut() { let last_idx = *last_group.last().unwrap(); let last_shape_key = graph[last_idx].shape.clone(); last_shape_key == shape_key @@ -328,15 +323,15 @@ impl BackendDevice for CudaDevice { false }; if should_group { - splits.last_mut().unwrap().push(idx); + splits.last_mut().unwrap().0.push(idx); } else { - splits.push(vec![idx]); + splits.push((vec![idx], shape_key)); } } // For each group of nodes with matching input shapes/dtype, generate and run kernels - let mut last_storage = None; - for sub_order in splits { + let mut last_storage = HashMap::new(); + for (sub_order, shape) in splits { // build header/body for this subgraph slice let mut header = String::new(); let body = handle_node( @@ -346,11 +341,11 @@ impl BackendDevice for CudaDevice { graph, ); // launch a kernel for this subgroup - let storage = self.run_graph::(header.clone(), body.clone())?; - last_storage = Some(storage); - // handle or collect `storage` as needed; here we return the last one + let storage = self.run_graph::(header.clone(), body.clone(), shape)?; + last_storage.insert(*sub_order.iter().max().unwrap(), storage); } - // Return the last storage (corresponds to the output node) - Ok(last_storage.expect("No nodes to execute")) + + let key = *last_storage.keys().max().unwrap(); + Ok(last_storage.remove(&key).unwrap()) } } From 1e5e7d7f7308e0fcad9fe7ea8e91afb674ed8722 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 15:45:40 +0000 Subject: [PATCH 05/18] Split compile from run --- constensor-core/examples/hello_world/main.rs | 9 +- constensor-core/src/cpu_storage/mod.rs | 96 +++++++++++++------- constensor-core/src/cuda_backend/mod.rs | 82 +++++++++++++---- constensor-core/src/device.rs | 22 ++++- constensor-core/src/graph.rs | 35 ++++++- constensor-core/src/lib.rs | 2 +- constensor-core/src/storage.rs | 11 ++- constensor-core/src/tensor/graphtensor.rs | 13 +-- 8 files changed, 191 insertions(+), 79 deletions(-) diff --git a/constensor-core/examples/hello_world/main.rs b/constensor-core/examples/hello_world/main.rs index 2b8c07c..b115a9d 100644 --- a/constensor-core/examples/hello_world/main.rs +++ b/constensor-core/examples/hello_world/main.rs @@ -2,20 +2,21 @@ use constensor_core::{Cpu, Cuda, Graph, GraphTensor, Tensor, R1, R2}; fn main() { let mut graph: Graph = Graph::empty(); - let arange = GraphTensor::, f32, Cuda<0>>::arange(&mut graph, 0., 1.); - dbg!(&arange.to_tensor().unwrap().data()); + // let arange = GraphTensor::, f32, Cuda<0>>::arange(&mut graph, 0., 1.); let a = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 1.0); let b = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 2.0); let c = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 3.0); let d = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 4.0); let res = a * b + c; - let res = res + d; + let _out = res + d; graph.optimize(); + let compiled = graph.compile().unwrap(); + let res: Tensor, f32, _> = compiled.run().unwrap(); // graph.visualize("graph.png").unwrap(); - let tensor: Tensor, f32, Cuda<0>> = res.to_tensor().unwrap(); + let tensor: Tensor, f32, Cuda<0>> = res; assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![9.0; 4]; 3],); } diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index 3dd73c1..9cdeb7d 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -1,15 +1,17 @@ use petgraph::algo::toposort; use petgraph::graphmap::DiGraphMap; -use std::borrow::Cow; use std::cell::RefCell; use std::rc::Rc; +use std::{borrow::Cow, marker::PhantomData}; use pool::{BufferPool, PooledBuffer}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use crate::device::Dev; +use crate::Shape; use crate::{ storage::{BackendDevice, BackendStorage}, - DType, GraphNode, Op, Result, + CompiledGraph, DType, GraphNode, Op, Result, }; mod pool; @@ -29,49 +31,73 @@ impl BackendStorage for CpuStorage { impl BackendDevice for CpuDevice { type Storage = CpuStorage; - fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { - { - // Create a shared buffer pool - let pool = Rc::new(RefCell::new(BufferPool::::new())); - - // Build a dependency graph of tensor indices - let mut dep_graph = DiGraphMap::::new(); - for idx in 0..graph.len() { - dep_graph.add_node(idx); - } + fn compile( + &self, + graph: Vec>, + ) -> Result> { + // Build a dependency graph of tensor indices + let mut dep_graph = DiGraphMap::::new(); + for idx in 0..graph.len() { + dep_graph.add_node(idx); + } - for (idx, node) in graph.iter().enumerate() { - match &node.op { - Op::BinaryOp { l_id, r_id, .. } => { - dep_graph.add_edge(l_id.get(), idx, ()); - dep_graph.add_edge(r_id.get(), idx, ()); - } - Op::UnaryOp { v_id, .. } => { - dep_graph.add_edge(v_id.get(), idx, ()); - } - Op::FusedMulAdd { a_id, b_id, c_id } => { - dep_graph.add_edge(a_id.get(), idx, ()); - dep_graph.add_edge(b_id.get(), idx, ()); - dep_graph.add_edge(c_id.get(), idx, ()); - } - Op::MatMul { l_id, r_id, .. } => { - dep_graph.add_edge(l_id.get(), idx, ()); - dep_graph.add_edge(r_id.get(), idx, ()); - } - // NoOp and Fill/Arange don’t create incoming edges - Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} + for (idx, node) in graph.iter().enumerate() { + match &node.op { + Op::BinaryOp { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); } + Op::UnaryOp { v_id, .. } => { + dep_graph.add_edge(v_id.get(), idx, ()); + } + Op::FusedMulAdd { a_id, b_id, c_id } => { + dep_graph.add_edge(a_id.get(), idx, ()); + dep_graph.add_edge(b_id.get(), idx, ()); + dep_graph.add_edge(c_id.get(), idx, ()); + } + Op::MatMul { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); + } + // NoOp and Fill/Arange don’t create incoming edges + Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} } + } + + // Compute topological order + let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); + + Ok(CompiledGraph::Cpu { + order, + graph, + ghost: PhantomData, + }) + } + + fn run_graph( + &self, + graph: &CompiledGraph, + ) -> Result> { + { + // Create a shared buffer pool + let pool = Rc::new(RefCell::new(BufferPool::::new())); - // Compute topological order - let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); + #[allow(irrefutable_let_patterns)] + let CompiledGraph::Cpu { + order, + graph, + ghost: _, + } = graph + else { + unreachable!() + }; // Prepare storage for intermediate results let mut results: Vec>> = Vec::with_capacity(graph.len()); results.resize_with(graph.len(), || None); // Evaluate nodes in topological order - for idx in order { + for idx in order.clone() { let op = &graph[idx]; let out_shape = &op.shape; diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index b2f5ccc..79f128a 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -3,6 +3,7 @@ use std::{ collections::HashMap, fs, hash::{DefaultHasher, Hash, Hasher}, + marker::PhantomData, ops::Deref, path::{Path, PathBuf}, sync::{Arc, RwLock}, @@ -17,8 +18,9 @@ use petgraph::{algo::toposort, prelude::DiGraphMap}; use crate::{ cpu_storage::CpuStorage, + device::Dev, storage::{BackendDevice, BackendStorage}, - DType, GraphNode, Op, Result, + CompiledGraph, DType, GraphNode, Op, Result, Shape, }; #[derive(Clone)] @@ -186,12 +188,34 @@ fn compile_ptx(template_kernel: String) -> Result { } impl CudaDevice { - fn run_graph( + fn run_kernel( + &self, + func: &CudaFunction, + data: &CudaSlice, + shape: &Vec, + ) -> Result> { + let n_elems: usize = shape.iter().product(); + let stream = self.stream(); + + let cfg = LaunchConfig::for_num_elems(n_elems as u32); + + let mut builder = stream.launch_builder(&func); + builder.arg(data); + builder.arg(&n_elems); + unsafe { builder.launch(cfg).w()? }; + + Ok(CudaStorage { + slice: data.clone(), + device: self.clone(), + }) + } + + fn compile_kernel( &self, header: String, body: String, shape: Vec, - ) -> Result> { + ) -> Result<(CudaFunction, CudaSlice)> { // Module name is based on hash of body and header let mut hasher = DefaultHasher::new(); body.hash(&mut hasher); @@ -260,24 +284,17 @@ impl CudaDevice { let func = self.load_func(&function_name, ptx)?; - let cfg = LaunchConfig::for_num_elems(n_elems as u32); - - let mut builder = stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&n_elems); - unsafe { builder.launch(cfg).w()? }; - - Ok(CudaStorage { - slice: data, - device: self.clone(), - }) + Ok((func, data)) } } impl BackendDevice for CudaDevice { type Storage = CudaStorage; - fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { + fn compile( + &self, + graph: Vec>, + ) -> Result> { // Build a dependency graph of tensor indices let mut dep_graph = DiGraphMap::::new(); for idx in 0..graph.len() { @@ -329,8 +346,8 @@ impl BackendDevice for CudaDevice { } } - // For each group of nodes with matching input shapes/dtype, generate and run kernels - let mut last_storage = HashMap::new(); + // For each group of nodes with matching input shapes/dtype, generate kernels + let mut kernels = Vec::new(); for (sub_order, shape) in splits { // build header/body for this subgraph slice let mut header = String::new(); @@ -338,11 +355,36 @@ impl BackendDevice for CudaDevice { &mut 0, &mut header, &graph[*sub_order.last().unwrap()], - graph, + &graph, ); // launch a kernel for this subgroup - let storage = self.run_graph::(header.clone(), body.clone(), shape)?; - last_storage.insert(*sub_order.iter().max().unwrap(), storage); + let (func, slice) = + self.compile_kernel::(header.clone(), body.clone(), shape.clone())?; + kernels.push((func, slice, shape, *sub_order.iter().max().unwrap())) + } + + Ok(CompiledGraph::Cuda { + kernels, + ghost: PhantomData, + }) + } + + fn run_graph( + &self, + graph: &CompiledGraph, + ) -> Result> { + #[allow(irrefutable_let_patterns)] + let CompiledGraph::Cuda { kernels, ghost: _ } = graph + else { + unreachable!() + }; + + // For each group of nodes with matching input shapes/dtype, generate and run kernels + let mut last_storage = HashMap::new(); + for (func, slice, shape, order) in kernels { + // launch a kernel for this subgroup + let storage = self.run_kernel::(func, slice, shape)?; + last_storage.insert(order, storage); } let key = *last_storage.keys().max().unwrap(); diff --git a/constensor-core/src/device.rs b/constensor-core/src/device.rs index 58bad50..7825235 100644 --- a/constensor-core/src/device.rs +++ b/constensor-core/src/device.rs @@ -3,7 +3,7 @@ use crate::cuda_backend::CudaDevice; use crate::{ cpu_storage::CpuDevice, storage::{BackendDevice, Storage}, - DType, GraphNode, Result, + CompiledGraph, DType, GraphNode, Result, Shape, }; /// Marker trait for devices @@ -66,11 +66,25 @@ pub enum Device { } impl Device { - pub fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { + pub fn run_graph( + &self, + graph: &CompiledGraph, + ) -> Result> { match self { #[cfg(feature = "cuda")] - Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.compile_and_run_graph::(graph)?)), - Self::Cpu => Ok(Storage::Cpu(CpuDevice.compile_and_run_graph::(graph)?)), + Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.run_graph::(graph)?)), + Self::Cpu => Ok(Storage::Cpu(CpuDevice.run_graph::(graph)?)), + } + } + + pub fn compile( + &self, + graph: Vec>, + ) -> Result> { + match self { + #[cfg(feature = "cuda")] + Self::Cuda(cuda) => cuda.compile::(graph), + Self::Cpu => CpuDevice.compile::(graph), } } } diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index f71f229..3920218 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -5,13 +5,14 @@ use std::{ fmt::Display, fs, hash::Hash, + marker::PhantomData, path::Path, process::Command, rc::Rc, sync::{Arc, RwLock, RwLockReadGuard}, }; -use crate::{DType, Result, Shape}; +use crate::{device::Dev, tensor::concretetensor::from_storage, DType, Result, Shape, Tensor}; use petgraph::Graph as PetGraph; use petgraph::{dot::Dot, graph::NodeIndex}; @@ -441,6 +442,38 @@ impl Graph { self.optimize_inplace_fma(); self.optimize_inplace_matmul(); } + + pub fn compile(self) -> Result> { + let device = D::resolve()?; + + device.compile(self.data.read().unwrap().clone()) + } +} + +pub enum CompiledGraph { + Cpu { + order: Vec, + graph: Vec>, + ghost: PhantomData<(S, T, D)>, + }, + #[cfg(feature = "cuda")] + Cuda { + kernels: Vec<( + cudarc::driver::CudaFunction, + cudarc::driver::CudaSlice, + Vec, + usize, + )>, + ghost: PhantomData<(S, T, D)>, + }, +} + +impl CompiledGraph { + pub fn run(&self) -> Result> { + let device = D::resolve()?; + let storage = device.run_graph(self)?; + Ok(from_storage(Arc::new(storage))) + } } #[derive(PartialEq, Debug, Clone, Copy)] diff --git a/constensor-core/src/lib.rs b/constensor-core/src/lib.rs index c8af30a..c8dffc9 100644 --- a/constensor-core/src/lib.rs +++ b/constensor-core/src/lib.rs @@ -16,6 +16,6 @@ pub use device::Cpu; pub use device::Cuda; pub use dtype::DType; pub use error::{Error, Result}; -pub use graph::{Graph, GraphNode, Op}; +pub use graph::{CompiledGraph, Graph, GraphNode, Op}; pub use shape::{Shape, R1, R2, R3, R4, R5, R6}; pub use tensor::{GraphTensor, Tensor}; diff --git a/constensor-core/src/storage.rs b/constensor-core/src/storage.rs index 2d10bf6..8394207 100644 --- a/constensor-core/src/storage.rs +++ b/constensor-core/src/storage.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; #[cfg(feature = "cuda")] use crate::cuda_backend::CudaStorage; -use crate::{cpu_storage::CpuStorage, DType, GraphNode, Result}; +use crate::{cpu_storage::CpuStorage, device::Dev, CompiledGraph, DType, GraphNode, Result, Shape}; pub enum Storage { #[cfg(feature = "cuda")] @@ -27,5 +27,12 @@ pub trait BackendStorage { pub trait BackendDevice { type Storage: BackendStorage; - fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result>; + fn compile( + &self, + graph: Vec>, + ) -> Result>; + fn run_graph( + &self, + graph: &CompiledGraph, + ) -> Result>; } diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs index 6269718..1872a4d 100644 --- a/constensor-core/src/tensor/graphtensor.rs +++ b/constensor-core/src/tensor/graphtensor.rs @@ -7,8 +7,7 @@ use std::{ use crate::{ device::Dev, graph::{BinaryOpType, Graph, GraphTensorId, Op, UnaryOpType}, - tensor::concretetensor::from_storage, - DType, Result, Shape, Tensor, R1, R3, + DType, Shape, R1, R3, }; /// A tensor representing an intermediary result of a graph. Performing operations @@ -126,16 +125,6 @@ impl GraphTensor { pub fn id(&self) -> GraphTensorId { self.id.clone() } - - /// Convert this `GraphTensor` into a concrete `Tensor`. - pub fn to_tensor(self) -> Result> { - let graph = self.graph.read().unwrap(); - let nodes = &*graph.get_ops(); - - let device = D::resolve()?; - let storage = device.compile_and_run_graph::(nodes)?; - Ok(from_storage(Arc::new(storage))) - } } impl GraphTensor, T, D> { From fb314a006c6952348a46fc2a8247889bb42caf7e Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 15:48:11 +0000 Subject: [PATCH 06/18] Cuda compiled kernel --- constensor-core/src/cuda_backend/mod.rs | 22 ++++++++++++++++++++-- constensor-core/src/graph.rs | 7 +------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 79f128a..3780772 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -73,6 +73,13 @@ impl BackendStorage for CudaStorage { } } +pub struct CudaCompiledKernel { + func: CudaFunction, + slice: CudaSlice, + shape: Vec, + order: usize, +} + #[derive(Debug)] struct Name(usize); impl Name { @@ -360,7 +367,12 @@ impl BackendDevice for CudaDevice { // launch a kernel for this subgroup let (func, slice) = self.compile_kernel::(header.clone(), body.clone(), shape.clone())?; - kernels.push((func, slice, shape, *sub_order.iter().max().unwrap())) + kernels.push(CudaCompiledKernel { + func, + slice, + shape, + order: *sub_order.iter().max().unwrap(), + }) } Ok(CompiledGraph::Cuda { @@ -381,7 +393,13 @@ impl BackendDevice for CudaDevice { // For each group of nodes with matching input shapes/dtype, generate and run kernels let mut last_storage = HashMap::new(); - for (func, slice, shape, order) in kernels { + for CudaCompiledKernel { + func, + slice, + shape, + order, + } in kernels + { // launch a kernel for this subgroup let storage = self.run_kernel::(func, slice, shape)?; last_storage.insert(order, storage); diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index 3920218..15db280 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -458,12 +458,7 @@ pub enum CompiledGraph { }, #[cfg(feature = "cuda")] Cuda { - kernels: Vec<( - cudarc::driver::CudaFunction, - cudarc::driver::CudaSlice, - Vec, - usize, - )>, + kernels: Vec>, ghost: PhantomData<(S, T, D)>, }, } From 001d65565f8d0c3e383e9b2948f829b2deb7b4c1 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 20:23:43 +0000 Subject: [PATCH 07/18] Fix tests --- constensor-core/examples/hello_world/main.rs | 21 +++-- constensor-core/examples/matmul/main.rs | 21 +++-- constensor-core/src/graph.rs | 1 + constensor-core/tests/fma.rs | 22 +++-- constensor-core/tests/ops.rs | 97 ++++++++++++-------- 5 files changed, 94 insertions(+), 68 deletions(-) diff --git a/constensor-core/examples/hello_world/main.rs b/constensor-core/examples/hello_world/main.rs index b115a9d..372476e 100644 --- a/constensor-core/examples/hello_world/main.rs +++ b/constensor-core/examples/hello_world/main.rs @@ -1,22 +1,23 @@ -use constensor_core::{Cpu, Cuda, Graph, GraphTensor, Tensor, R1, R2}; +use constensor_core::{Cpu, Graph, GraphTensor, Tensor, R1, R2}; fn main() { let mut graph: Graph = Graph::empty(); - // let arange = GraphTensor::, f32, Cuda<0>>::arange(&mut graph, 0., 1.); - let a = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 1.0); - let b = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 2.0); - let c = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 3.0); - let d = GraphTensor::, f32, Cuda<0>>::fill(&mut graph, 4.0); + let _arange = GraphTensor::, f32, Cpu>::arange(&mut graph, 0., 1.); + let a = GraphTensor::, f32, Cpu>::fill(&mut graph, 1.0); + let b = GraphTensor::, f32, Cpu>::fill(&mut graph, 2.0); + let c = GraphTensor::, f32, Cpu>::fill(&mut graph, 3.0); + let d = GraphTensor::, f32, Cpu>::fill(&mut graph, 4.0); let res = a * b + c; let _out = res + d; graph.optimize(); - let compiled = graph.compile().unwrap(); - let res: Tensor, f32, _> = compiled.run().unwrap(); - // graph.visualize("graph.png").unwrap(); + graph.visualize("graph.png").unwrap(); - let tensor: Tensor, f32, Cuda<0>> = res; + let compiled: constensor_core::CompiledGraph, f32, Cpu> = graph.compile().unwrap(); + let res = compiled.run().unwrap(); + + let tensor: Tensor, f32, Cpu> = res; assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![9.0; 4]; 3],); } diff --git a/constensor-core/examples/matmul/main.rs b/constensor-core/examples/matmul/main.rs index ce2897f..8dceab2 100644 --- a/constensor-core/examples/matmul/main.rs +++ b/constensor-core/examples/matmul/main.rs @@ -1,4 +1,4 @@ -use constensor_core::{Cpu, DType, Graph, GraphTensor, R3}; +use constensor_core::{CompiledGraph, Cpu, DType, Graph, GraphTensor, R3}; use std::time::Instant; fn bench( @@ -10,18 +10,19 @@ fn bench, T, Cpu>::ones(&mut graph); + let b = GraphTensor::, T, Cpu>::ones(&mut graph); + let o = GraphTensor::, T, Cpu>::ones(&mut graph); + let _c = a.matmul_axpby(b, o, alpha, beta); - let mut graph = Graph::empty(); - let a = GraphTensor::, T, Cpu>::ones(&mut graph); - let b = GraphTensor::, T, Cpu>::ones(&mut graph); - let o = GraphTensor::, T, Cpu>::ones(&mut graph); - let c = a.matmul_axpby(b, o, alpha, beta); + graph.optimize(); + let compiled: CompiledGraph, T, Cpu> = graph.compile().unwrap(); - graph.optimize(); + for _ in 0..iterations { + let start = Instant::now(); - let _tensor = std::hint::black_box(c.to_tensor().unwrap()); + let _tensor = std::hint::black_box(compiled.run().unwrap()); total += start.elapsed(); } diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index 15db280..6f00c67 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -450,6 +450,7 @@ impl Graph { } } +/// A representation of the compiled graph. The shape is the output shape. pub enum CompiledGraph { Cpu { order: Vec, diff --git a/constensor-core/tests/fma.rs b/constensor-core/tests/fma.rs index ed3a838..1bf9ca3 100644 --- a/constensor-core/tests/fma.rs +++ b/constensor-core/tests/fma.rs @@ -1,6 +1,6 @@ #[cfg(feature = "cuda")] use constensor_core::Cuda; -use constensor_core::{Cpu, Graph, GraphTensor, R2}; +use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R2}; macro_rules! test_for_device_fma { ($dev:ty, $name:ident) => { @@ -12,8 +12,9 @@ macro_rules! test_for_device_fma { let a = GraphTensor::, f32, $dev>::fill(&mut graph, 2.0); let b = GraphTensor::, f32, $dev>::fill(&mut graph, 3.0); let c = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![10.0; 4]; 3],); } @@ -23,8 +24,9 @@ macro_rules! test_for_device_fma { let a = GraphTensor::, i32, $dev>::fill(&mut graph, 2); let b = GraphTensor::, i32, $dev>::fill(&mut graph, 3); let c = GraphTensor::, i32, $dev>::fill(&mut graph, 4); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![10; 4]; 3],); } } @@ -47,8 +49,9 @@ macro_rules! test_for_device_half_fma { GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(3.0)); let c = GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(4.0)); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![f16::from_f64_const(10.0); 4]; 3], @@ -80,8 +83,9 @@ macro_rules! test_for_device_bfloat_fma { &mut graph, bf16::from_f64_const(4.0), ); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![bf16::from_f64_const(10.0); 4]; 3], diff --git a/constensor-core/tests/ops.rs b/constensor-core/tests/ops.rs index deba8ec..c474b6e 100644 --- a/constensor-core/tests/ops.rs +++ b/constensor-core/tests/ops.rs @@ -1,6 +1,6 @@ #[cfg(feature = "cuda")] use constensor_core::Cuda; -use constensor_core::{Cpu, Graph, GraphTensor, R1, R2, R3}; +use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R1, R2, R3}; #[cfg(feature = "bfloat")] use half::bf16; #[cfg(feature = "half")] @@ -13,8 +13,9 @@ macro_rules! test_for_device_float { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = GraphTensor::, f32, $dev>::fill(&mut graph, 0.0); - let tensor = gt.to_tensor().unwrap(); + let _gt = GraphTensor::, f32, $dev>::fill(&mut graph, 0.0); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -32,8 +33,9 @@ macro_rules! test_for_device_float { let y = GraphTensor::, f32, $dev>::fill(&mut graph, 2.0); let z = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -49,8 +51,9 @@ macro_rules! test_for_device_float { let mut graph = Graph::empty(); let x = GraphTensor::, f32, $dev>::fill(&mut graph, 1.0); let y = GraphTensor::, f32, $dev>::arange(&mut graph, 0.0, 1.0); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![1.0, 1.25, 1.5, 1.75]); } @@ -59,8 +62,9 @@ macro_rules! test_for_device_float { let mut graph = Graph::empty(); let a = GraphTensor::, f32, $dev>::ones(&mut graph); let b = GraphTensor::, f32, $dev>::ones(&mut graph); - let c = a.matmul(b); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul(b); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[f32; 2]>; 1] = [vec![[3.0, 3.0], [3.0, 3.0]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } @@ -71,8 +75,9 @@ macro_rules! test_for_device_float { let a = GraphTensor::, f32, $dev>::ones(&mut graph); let b = GraphTensor::, f32, $dev>::ones(&mut graph); let o = GraphTensor::, f32, $dev>::ones(&mut graph); - let c = a.matmul_axpby(b, o, 1., 1.); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul_axpby(b, o, 1., 1.); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[f32; 2]>; 1] = [vec![[4.0, 4.0], [4.0, 4.0]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } @@ -91,8 +96,9 @@ macro_rules! test_for_device_int { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = GraphTensor::, i32, $dev>::fill(&mut graph, 0); - let tensor = gt.to_tensor().unwrap(); + let _gt = GraphTensor::, i32, $dev>::fill(&mut graph, 0); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![[0, 0, 0, 0,], [0, 0, 0, 0,], [0, 0, 0, 0,],], @@ -106,8 +112,9 @@ macro_rules! test_for_device_int { let y = GraphTensor::, i32, $dev>::fill(&mut graph, 2); let z = GraphTensor::, i32, $dev>::fill(&mut graph, 4); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![[1, 1, 1, 1,], [1, 1, 1, 1,], [1, 1, 1, 1,],], @@ -119,8 +126,9 @@ macro_rules! test_for_device_int { let mut graph = Graph::empty(); let x = GraphTensor::, i32, $dev>::fill(&mut graph, 1); let y = GraphTensor::, i32, $dev>::arange(&mut graph, 0, 4); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![1, 2, 3, 4]); } @@ -129,8 +137,9 @@ macro_rules! test_for_device_int { let mut graph = Graph::empty(); let a = GraphTensor::, i32, $dev>::ones(&mut graph); let b = GraphTensor::, i32, $dev>::ones(&mut graph); - let c = a.matmul(b); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul(b); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[i32; 2]>; 1] = [vec![[3, 3], [3, 3]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } @@ -141,8 +150,9 @@ macro_rules! test_for_device_int { let a = GraphTensor::, i32, $dev>::ones(&mut graph); let b = GraphTensor::, i32, $dev>::ones(&mut graph); let o = GraphTensor::, i32, $dev>::ones(&mut graph); - let c = a.matmul_axpby(b, o, 1, 1); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul_axpby(b, o, 1, 1); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[i32; 2]>; 1] = [vec![[4, 4], [4, 4]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } @@ -162,9 +172,10 @@ macro_rules! test_for_device_half { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = + let _gt = GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(0.0)); - let tensor = gt.to_tensor().unwrap(); + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![f16::from_f64_const(0.0); 4]; 3], @@ -181,8 +192,9 @@ macro_rules! test_for_device_half { let z = GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(4.0)); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![f16::from_f64_const(1.3330078); 4]; 3], @@ -198,8 +210,9 @@ macro_rules! test_for_device_half { f16::from_f64_const(0.0), f16::from_f64_const(1.0), ); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -227,11 +240,12 @@ macro_rules! test_for_device_bfloat { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = GraphTensor::, bf16, $dev>::fill( + let _gt = GraphTensor::, bf16, $dev>::fill( &mut graph, bf16::from_f64_const(0.0), ); - let tensor = gt.to_tensor().unwrap(); + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![bf16::from_f64_const(0.0); 4]; 3], @@ -254,8 +268,9 @@ macro_rules! test_for_device_bfloat { bf16::from_f64_const(4.0), ); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![bf16::from_f64_const(1.3330078); 4]; 3], @@ -272,8 +287,9 @@ macro_rules! test_for_device_bfloat { bf16::from_f64_const(0.0), bf16::from_f64_const(1.0), ); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -305,8 +321,9 @@ macro_rules! test_for_device_float_unary { let y = GraphTensor::, f32, $dev>::fill(&mut graph, 2.0); let z = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); let c = x + -y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![-4.0; 4]; 3],); } } @@ -326,8 +343,9 @@ macro_rules! test_for_device_sqrt { fn sqrt_float() { let mut graph = Graph::empty(); let x = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); - let res = x.sqrt(); - let tensor = res.to_tensor().unwrap(); + let _res = x.sqrt(); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![2.0; 4]; 3],); } @@ -335,8 +353,9 @@ macro_rules! test_for_device_sqrt { fn sqrt_int() { let mut graph = Graph::empty(); let x = GraphTensor::, i32, $dev>::fill(&mut graph, 5); - let res = x.sqrt(); - let tensor = res.to_tensor().unwrap(); + let _res = x.sqrt(); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![2; 4]; 3],); } } From d2d219dc831e4086a7464e9569f2a323128ef8dd Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 20:31:32 +0000 Subject: [PATCH 08/18] Check shapes in compile --- constensor-core/src/error.rs | 116 +++++++++++++++++++++++++++++++++++ constensor-core/src/graph.rs | 17 +++++ constensor-core/src/lib.rs | 2 +- constensor-core/tests/ops.rs | 8 +-- 4 files changed, 138 insertions(+), 5 deletions(-) diff --git a/constensor-core/src/error.rs b/constensor-core/src/error.rs index 983affc..75d9cf6 100644 --- a/constensor-core/src/error.rs +++ b/constensor-core/src/error.rs @@ -1,3 +1,5 @@ +use std::{convert::Infallible, fmt::Display}; + #[derive(thiserror::Error, Debug)] pub enum Error { #[error(transparent)] @@ -14,11 +16,29 @@ pub enum Error { #[error("IO error: {0}")] IoError(String), + + /// Arbitrary errors wrapping. + #[error(transparent)] + Wrapped(Box), + + /// Arbitrary errors wrapping with context. + #[error("{wrapped:?}\n{context:?}")] + WrappedContext { + wrapped: Box, + context: String, + }, } pub type Result = std::result::Result; impl Error { + /// Create a new error based on a printable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. + pub fn msg(msg: M) -> Self { + Self::Msg(msg.to_string()).bt() + } + pub fn bt(self) -> Self { let backtrace = std::backtrace::Backtrace::capture(); match backtrace.status() { @@ -37,3 +57,99 @@ impl From for Error { Error::IoError(value.to_string()) } } + +#[macro_export] +macro_rules! bail { + ($msg:literal $(,)?) => { + return Err($crate::Error::Msg(format!($msg).into()).bt()) + }; + ($err:expr $(,)?) => { + return Err($crate::Error::Msg(format!($err).into()).bt()) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt()) + }; +} + +pub(crate) mod private { + pub trait Sealed {} + + impl Sealed for std::result::Result where E: std::error::Error {} + impl Sealed for Option {} +} + +/// Attach more context to an error. +/// +/// Inspired by [`anyhow::Context`]. +pub trait Context: private::Sealed { + /// Wrap the error value with additional context. + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context(self, f: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl Context for std::result::Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using map_err to save 2 useless frames off the captured backtrace + // in ext_context. + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context.to_string(), + }), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context().to_string(), + }), + } + } +} + +impl Context for Option { + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using ok_or_else to save 2 useless frames off the captured + // backtrace. + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context)), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context())), + } + } +} diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index 6f00c67..152fc75 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -444,6 +444,23 @@ impl Graph { } pub fn compile(self) -> Result> { + if self + .data + .read() + .unwrap() + .last() + .is_some_and(|last| last.shape != S::shape()) + { + let read = self.data.read(); + let last = read.as_ref().unwrap().last().unwrap(); + + crate::bail!( + "Graph compiled shape is {:?} does not match the last node shape {:?}!", + &last.shape, + S::shape() + ); + } + let device = D::resolve()?; device.compile(self.data.read().unwrap().clone()) diff --git a/constensor-core/src/lib.rs b/constensor-core/src/lib.rs index c8dffc9..a614f3e 100644 --- a/constensor-core/src/lib.rs +++ b/constensor-core/src/lib.rs @@ -15,7 +15,7 @@ pub use device::Cpu; #[cfg(feature = "cuda")] pub use device::Cuda; pub use dtype::DType; -pub use error::{Error, Result}; +pub use error::{Context, Error, Result}; pub use graph::{CompiledGraph, Graph, GraphNode, Op}; pub use shape::{Shape, R1, R2, R3, R4, R5, R6}; pub use tensor::{GraphTensor, Tensor}; diff --git a/constensor-core/tests/ops.rs b/constensor-core/tests/ops.rs index c474b6e..fc7764b 100644 --- a/constensor-core/tests/ops.rs +++ b/constensor-core/tests/ops.rs @@ -63,7 +63,7 @@ macro_rules! test_for_device_float { let a = GraphTensor::, f32, $dev>::ones(&mut graph); let b = GraphTensor::, f32, $dev>::ones(&mut graph); let _c = a.matmul(b); - let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); let tensor = compiled.run().unwrap(); let expected: [Vec<[f32; 2]>; 1] = [vec![[3.0, 3.0], [3.0, 3.0]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); @@ -76,7 +76,7 @@ macro_rules! test_for_device_float { let b = GraphTensor::, f32, $dev>::ones(&mut graph); let o = GraphTensor::, f32, $dev>::ones(&mut graph); let _c = a.matmul_axpby(b, o, 1., 1.); - let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); let tensor = compiled.run().unwrap(); let expected: [Vec<[f32; 2]>; 1] = [vec![[4.0, 4.0], [4.0, 4.0]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); @@ -138,7 +138,7 @@ macro_rules! test_for_device_int { let a = GraphTensor::, i32, $dev>::ones(&mut graph); let b = GraphTensor::, i32, $dev>::ones(&mut graph); let _c = a.matmul(b); - let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); let tensor = compiled.run().unwrap(); let expected: [Vec<[i32; 2]>; 1] = [vec![[3, 3], [3, 3]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); @@ -151,7 +151,7 @@ macro_rules! test_for_device_int { let b = GraphTensor::, i32, $dev>::ones(&mut graph); let o = GraphTensor::, i32, $dev>::ones(&mut graph); let _c = a.matmul_axpby(b, o, 1, 1); - let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); let tensor = compiled.run().unwrap(); let expected: [Vec<[i32; 2]>; 1] = [vec![[4, 4], [4, 4]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); From df2011a1757402c12fbcce96f14bcd8c2df6c558 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 20:40:43 +0000 Subject: [PATCH 09/18] Fix clippy --- constensor-core/src/cuda_backend/mod.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 3780772..28b2821 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -115,8 +115,11 @@ fn handle_node( *header += &format!("T {} = {v:?};\n", name.to_name()); format!("({})", name.to_name()) } - Op::Arange { start, step, stop } => { - // todo!(); + Op::Arange { + start, + step, + stop: _, + } => { *current_name += 1; let name = Name(*current_name); *header += &format!( From f5f173af454cf08ea35fb92d0b90a593fc7ad4b9 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 21:31:28 +0000 Subject: [PATCH 10/18] Add gemm code --- .cargo/config.toml | 8 ++ constensor-core/src/cuda_backend/error.rs | 14 +- constensor-core/src/cuda_backend/mod.rs | 164 ++++++++++++++++------ constensor-core/src/cuda_backend/util.rs | 97 +++++++++++++ constensor-core/src/dtype/gemm.rs | 94 +++++++++++++ constensor-core/src/error.rs | 7 + 6 files changed, 332 insertions(+), 52 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 constensor-core/src/cuda_backend/util.rs diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..ca9d853 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,8 @@ +[build] +rustflags = ["-C", "target-cpu=native"] + +[target.wasm32-unknown-unknown] +rustflags = ["-C", "target-feature=+simd128"] + +[target.x86_64-apple-darwin] +rustflags = ["-C", "target-feature=-avx,-avx2"] \ No newline at end of file diff --git a/constensor-core/src/cuda_backend/error.rs b/constensor-core/src/cuda_backend/error.rs index 6539d64..7d85f21 100644 --- a/constensor-core/src/cuda_backend/error.rs +++ b/constensor-core/src/cuda_backend/error.rs @@ -6,6 +6,12 @@ pub enum CudaError { #[error(transparent)] Cuda(#[from] cudarc::driver::DriverError), + #[error(transparent)] + Compiler(#[from] cudarc::nvrtc::CompileError), + + #[error(transparent)] + Cublas(#[from] cudarc::cublas::result::CublasError), + #[error("{cuda} when loading {module_name}")] Load { cuda: cudarc::driver::DriverError, @@ -34,11 +40,3 @@ impl> WrapErr for std::result::Result { self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt()) } } - -impl WrapErr for std::result::Result { - fn w(self) -> std::result::Result { - self.map_err(|e| { - crate::Error::Cuda(Box::new(CudaError::PtxCompileError { err: e }).into()).bt() - }) - } -} diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 28b2821..aba848e 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -1,3 +1,10 @@ +use cudarc::{ + cublas::CudaBlas, + driver::{CudaFunction, CudaModule, CudaSlice, LaunchConfig, PushKernelArg}, + nvrtc::{CompileOptions, Ptx}, +}; +use error::WrapErr; +use petgraph::{algo::toposort, prelude::DiGraphMap}; use std::{ borrow::Cow, collections::HashMap, @@ -8,13 +15,6 @@ use std::{ path::{Path, PathBuf}, sync::{Arc, RwLock}, }; -mod error; -use cudarc::{ - driver::{CudaFunction, CudaModule, CudaSlice, LaunchConfig, PushKernelArg}, - nvrtc::{CompileOptions, Ptx}, -}; -use error::WrapErr; -use petgraph::{algo::toposort, prelude::DiGraphMap}; use crate::{ cpu_storage::CpuStorage, @@ -23,6 +23,9 @@ use crate::{ CompiledGraph, DType, GraphNode, Op, Result, Shape, }; +pub(crate) mod error; +pub(crate) mod util; + #[derive(Clone)] pub struct CudaDevice { context: Arc, @@ -73,11 +76,24 @@ impl BackendStorage for CudaStorage { } } -pub struct CudaCompiledKernel { - func: CudaFunction, - slice: CudaSlice, - shape: Vec, - order: usize, +pub enum CudaCompiledKernel { + /// JIT‑compiled element‑wise kernel produced by `compile_kernel`. + ElementWise { + func: CudaFunction, + slice: CudaSlice, + shape: Vec, + order: usize, + }, + /// Matrix–multiplication kernel to be executed through cuBLAS. + MatMul { + l_id: usize, + r_id: usize, + b: usize, + m: usize, + n: usize, + k: usize, + order: usize, + }, } #[derive(Debug)] @@ -202,14 +218,14 @@ impl CudaDevice { &self, func: &CudaFunction, data: &CudaSlice, - shape: &Vec, + shape: &[usize], ) -> Result> { let n_elems: usize = shape.iter().product(); let stream = self.stream(); let cfg = LaunchConfig::for_num_elems(n_elems as u32); - let mut builder = stream.launch_builder(&func); + let mut builder = stream.launch_builder(func); builder.arg(data); builder.arg(&n_elems); unsafe { builder.launch(cfg).w()? }; @@ -337,29 +353,50 @@ impl BackendDevice for CudaDevice { // Compute topological order let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); - // Split into groups of nodes whose input shapes and dtype match + let mut kernels = Vec::>::new(); let mut splits: Vec<(Vec, Vec)> = Vec::new(); + for &idx in &order { - // Determine a key based on this node's input shapes - let shape_key: Vec = graph[idx].shape.clone(); - let should_group = if let Some((last_group, _)) = splits.last_mut() { - let last_idx = *last_group.last().unwrap(); - let last_shape_key = graph[last_idx].shape.clone(); - last_shape_key == shape_key - } else { - false - }; - if should_group { - splits.last_mut().unwrap().0.push(idx); - } else { - splits.push((vec![idx], shape_key)); + match &graph[idx].op { + // Give every MatMul its own split + Op::MatMul { l_id, r_id, .. } => { + let l_shape = &graph[l_id.get()].shape; + let r_shape = &graph[r_id.get()].shape; + assert_eq!(l_shape.len(), 3); + assert_eq!(r_shape.len(), 3); + let (b, m, k) = (l_shape[0], l_shape[1], l_shape[2]); + let n = r_shape[2]; + kernels.push(CudaCompiledKernel::MatMul { + l_id: l_id.get(), + r_id: r_id.get(), + b, + m, + n, + k, + order: idx, + }); + continue; // don’t add this node to an element‑wise split + } + _ => { + // existing element‑wise grouping logic + let shape_key = graph[idx].shape.clone(); + let should_group = if let Some((last_group, _)) = splits.last_mut() { + let last_idx = *last_group.last().unwrap(); + let last_shape_key = graph[last_idx].shape.clone(); + last_shape_key == shape_key + } else { + false + }; + if should_group { + splits.last_mut().unwrap().0.push(idx); + } else { + splits.push((vec![idx], shape_key)); + } + } } } - // For each group of nodes with matching input shapes/dtype, generate kernels - let mut kernels = Vec::new(); for (sub_order, shape) in splits { - // build header/body for this subgraph slice let mut header = String::new(); let body = handle_node( &mut 0, @@ -367,15 +404,14 @@ impl BackendDevice for CudaDevice { &graph[*sub_order.last().unwrap()], &graph, ); - // launch a kernel for this subgroup let (func, slice) = self.compile_kernel::(header.clone(), body.clone(), shape.clone())?; - kernels.push(CudaCompiledKernel { + kernels.push(CudaCompiledKernel::ElementWise { func, slice, shape, order: *sub_order.iter().max().unwrap(), - }) + }); } Ok(CompiledGraph::Cuda { @@ -396,16 +432,56 @@ impl BackendDevice for CudaDevice { // For each group of nodes with matching input shapes/dtype, generate and run kernels let mut last_storage = HashMap::new(); - for CudaCompiledKernel { - func, - slice, - shape, - order, - } in kernels - { - // launch a kernel for this subgroup - let storage = self.run_kernel::(func, slice, shape)?; - last_storage.insert(order, storage); + for kernel in kernels { + match kernel { + CudaCompiledKernel::ElementWise { + func, + slice, + shape, + order, + } => { + let storage = self.run_kernel::(func, slice, shape)?; + last_storage.insert(order, storage); + } + CudaCompiledKernel::MatMul { + l_id, + r_id, + b, + m, + n, + k, + order, + } => { + // obtain input buffers + let lhs = last_storage.get(&l_id).expect("lhs storage missing"); + let rhs = last_storage.get(&r_id).expect("rhs storage missing"); + + // allocate output buffer + let elems = { *m } * { *n }; + let mut out = unsafe { self.stream().alloc::(elems) }.w()?; + + let cublas = CudaBlas::new(self.stream()).unwrap(); + + T::launch_gemm_cuda( + cublas, + &lhs.slice, + &rhs.slice, + *b, + *m, + *n, + *k, + &mut out, + T::ZERO, + T::ONE, + )?; + + let storage = CudaStorage { + slice: out, + device: self.clone(), + }; + last_storage.insert(order, storage); + } + } } let key = *last_storage.keys().max().unwrap(); diff --git a/constensor-core/src/cuda_backend/util.rs b/constensor-core/src/cuda_backend/util.rs new file mode 100644 index 0000000..9ac05ff --- /dev/null +++ b/constensor-core/src/cuda_backend/util.rs @@ -0,0 +1,97 @@ +use cudarc::cublas::{GemmConfig, StridedBatchedConfig}; + +use crate::{Error, Result}; + +pub(crate) fn gemm_config( + alpha: T, + beta: T, + (b, m, n, k): (usize, usize, usize, usize), +) -> Result> { + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm + use cudarc::cublas::sys::cublasOperation_t; + + let lhs_stride = [m * k, k, 1]; + let rhs_stride = [k * n, n, 1]; + let lhs_dims = [b, m, k]; + let rhs_dims = [b, k, n]; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + // The a tensor has dims batching, k, n (rhs) + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, cublasOperation_t::CUBLAS_OP_N) + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { + (k as i32, cublasOperation_t::CUBLAS_OP_T) + } else { + Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })? + }; + // The b tensor has dims batching, m, k (lhs) + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, cublasOperation_t::CUBLAS_OP_N) + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { + (m as i32, cublasOperation_t::CUBLAS_OP_T) + } else { + Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })? + }; + // The setup below was copied from: + // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531 + let gemm = GemmConfig { + alpha, + beta, + m: n as i32, + n: m as i32, + k: k as i32, + lda, + ldb, + ldc: n as i32, + transa, + transb, + }; + + let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] { + [s1, stride] if s1 == stride * lhs_dims[1] => stride, + [_, stride] if lhs_dims[0] == 1 => stride, + [stride, _] if lhs_dims[1] == 1 => stride, + [stride] => stride, + [] => m * k, + _ => Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })?, + }; + let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] { + [s1, stride] if s1 == stride * rhs_dims[1] => stride, + [_, stride] if rhs_dims[0] == 1 => stride, + [stride, _] if rhs_dims[1] == 1 => stride, + [stride] => stride, + [] => n * k, + _ => Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })?, + }; + + Ok(StridedBatchedConfig { + batch_size: b as i32, + gemm, + stride_a: stride_a as i64, + stride_b: stride_b as i64, + stride_c: (m * n) as i64, + }) +} diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs index 1bff148..0dd4a6b 100644 --- a/constensor-core/src/dtype/gemm.rs +++ b/constensor-core/src/dtype/gemm.rs @@ -23,6 +23,94 @@ pub trait GemmDispatch { beta: Self, ) where Self: Sized; + + #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] + // Matrix multiplication: (B x M x K) * (B x K x N) = (B x M x N) + fn launch_gemm_cuda( + cublas: cudarc::cublas::CudaBlas, + lhs: &cudarc::driver::CudaSlice, + rhs: &cudarc::driver::CudaSlice, + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut cudarc::driver::CudaSlice, + alpha: Self, + beta: Self, + ) -> crate::Result<()> + where + Self: Sized; +} + +macro_rules! instantiate_gemm_cuda { + (u8) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + (u32) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + (i32) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + (i64) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + + (__instantiate_fail) => { + #[cfg(feature = "cuda")] + fn launch_gemm_cuda( + _cublas: cudarc::cublas::CudaBlas, + _lhs: &cudarc::driver::CudaSlice, + _rhs: &cudarc::driver::CudaSlice, + _b: usize, + _m: usize, + _n: usize, + _k: usize, + _out: &mut cudarc::driver::CudaSlice, + _alpha: Self, + _beta: Self, + ) -> crate::Result<()> + where + Self: Sized, + { + panic!("`launch_gemm_cuda` called with invalid configuration (w/o CUDA, dtype)") + } + }; + + ($rt:ident) => { + #[cfg(feature = "cuda")] + fn launch_gemm_cuda( + cublas: cudarc::cublas::CudaBlas, + lhs: &cudarc::driver::CudaSlice<$rt>, + rhs: &cudarc::driver::CudaSlice<$rt>, + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut cudarc::driver::CudaSlice<$rt>, + alpha: $rt, + beta: $rt, + ) -> crate::Result<()> { + use crate::cuda_backend::error::WrapErr; + use cudarc::cublas::Gemm; + + let gemm_cfg = crate::cuda_backend::util::gemm_config(alpha, beta, (b, m, n, k))?; + + unsafe { + cublas + .gemm_strided_batched( + gemm_cfg, + &lhs.as_view(), + &rhs.as_view(), + &mut out.as_view_mut(), + ) + .w()?; + } + + Ok(()) + } + }; } macro_rules! instantiate_gemm { @@ -54,6 +142,8 @@ macro_rules! instantiate_gemm { } } } + + instantiate_gemm_cuda!($rt); } }; @@ -121,6 +211,8 @@ macro_rules! instantiate_gemm { } } } + + instantiate_gemm_cuda!($rt); } }; // SIMD-accelerated gemm using SimdSupported for vectorized operations along 'n' dimension @@ -218,6 +310,8 @@ macro_rules! instantiate_gemm { } } } + + instantiate_gemm_cuda!($rt); } }; } diff --git a/constensor-core/src/error.rs b/constensor-core/src/error.rs index 75d9cf6..8ed3fab 100644 --- a/constensor-core/src/error.rs +++ b/constensor-core/src/error.rs @@ -27,6 +27,13 @@ pub enum Error { wrapped: Box, context: String, }, + + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] + MatMulNonContiguous { + lhs_stride: [usize; 3], + rhs_stride: [usize; 3], + mnk: (usize, usize, usize), + }, } pub type Result = std::result::Result; From 70254741ee8b0988160c9341ff46735ec5115dec Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 24 Apr 2025 21:35:55 +0000 Subject: [PATCH 11/18] Seperate splits --- constensor-core/src/cuda_backend/mod.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index aba848e..66e6c13 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -7,7 +7,7 @@ use error::WrapErr; use petgraph::{algo::toposort, prelude::DiGraphMap}; use std::{ borrow::Cow, - collections::HashMap, + collections::{HashMap, HashSet}, fs, hash::{DefaultHasher, Hash, Hasher}, marker::PhantomData, @@ -353,12 +353,21 @@ impl BackendDevice for CudaDevice { // Compute topological order let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); + // New kernel and grouping logic with matmul input tracking let mut kernels = Vec::>::new(); + let mut matmuls = Vec::>::new(); let mut splits: Vec<(Vec, Vec)> = Vec::new(); + // Collect all matmul input node indices + let mut matmul_inputs = HashSet::new(); + for &idx in &order { + if let Op::MatMul { l_id, r_id, .. } = &graph[idx].op { + matmul_inputs.insert(l_id.get()); + matmul_inputs.insert(r_id.get()); + } + } for &idx in &order { match &graph[idx].op { - // Give every MatMul its own split Op::MatMul { l_id, r_id, .. } => { let l_shape = &graph[l_id.get()].shape; let r_shape = &graph[r_id.get()].shape; @@ -366,7 +375,7 @@ impl BackendDevice for CudaDevice { assert_eq!(r_shape.len(), 3); let (b, m, k) = (l_shape[0], l_shape[1], l_shape[2]); let n = r_shape[2]; - kernels.push(CudaCompiledKernel::MatMul { + matmuls.push(CudaCompiledKernel::MatMul { l_id: l_id.get(), r_id: r_id.get(), b, @@ -375,15 +384,13 @@ impl BackendDevice for CudaDevice { k, order: idx, }); - continue; // don’t add this node to an element‑wise split } _ => { - // existing element‑wise grouping logic let shape_key = graph[idx].shape.clone(); let should_group = if let Some((last_group, _)) = splits.last_mut() { let last_idx = *last_group.last().unwrap(); let last_shape_key = graph[last_idx].shape.clone(); - last_shape_key == shape_key + last_shape_key == shape_key && !matmul_inputs.contains(&idx) } else { false }; @@ -396,6 +403,7 @@ impl BackendDevice for CudaDevice { } } + // Compile element‑wise splits first so matmul inputs are ready for (sub_order, shape) in splits { let mut header = String::new(); let body = handle_node( @@ -413,6 +421,8 @@ impl BackendDevice for CudaDevice { order: *sub_order.iter().max().unwrap(), }); } + // Then append all MatMul kernels + kernels.extend(matmuls); Ok(CompiledGraph::Cuda { kernels, From 5c5f5a190c56666ecbe470f132eed5baae20ce35 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 01:35:44 +0000 Subject: [PATCH 12/18] Fix axpby --- constensor-core/src/cuda_backend/mod.rs | 46 ++++++++++++++++--------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 66e6c13..09dbefd 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -88,11 +88,17 @@ pub enum CudaCompiledKernel { MatMul { l_id: usize, r_id: usize, + /// Optional output tensor ID for axpby semantics + o_id: Option, b: usize, m: usize, n: usize, k: usize, order: usize, + /// scale factor for existing output + alpha: T, + /// scale factor for lhs*rhs + beta: T, }, } @@ -368,21 +374,31 @@ impl BackendDevice for CudaDevice { for &idx in &order { match &graph[idx].op { - Op::MatMul { l_id, r_id, .. } => { + Op::MatMul { + l_id, + r_id, + o_id, + k, + alpha, + beta, + } => { let l_shape = &graph[l_id.get()].shape; let r_shape = &graph[r_id.get()].shape; assert_eq!(l_shape.len(), 3); assert_eq!(r_shape.len(), 3); - let (b, m, k) = (l_shape[0], l_shape[1], l_shape[2]); + let (b, m, _k) = (l_shape[0], l_shape[1], l_shape[2]); let n = r_shape[2]; matmuls.push(CudaCompiledKernel::MatMul { l_id: l_id.get(), r_id: r_id.get(), + o_id: o_id.as_ref().map(|id| id.get()), b, m, n, - k, + k: *k, order: idx, + alpha: *alpha, + beta: *beta, }); } _ => { @@ -456,33 +472,31 @@ impl BackendDevice for CudaDevice { CudaCompiledKernel::MatMul { l_id, r_id, + o_id, b, m, n, k, order, + alpha, + beta, } => { // obtain input buffers let lhs = last_storage.get(&l_id).expect("lhs storage missing"); let rhs = last_storage.get(&r_id).expect("rhs storage missing"); - // allocate output buffer - let elems = { *m } * { *n }; + let elems = m * n; + // prepare output buffer, copy initial if provided let mut out = unsafe { self.stream().alloc::(elems) }.w()?; + if let Some(o_idx) = o_id { + let init = last_storage.get(&o_idx).expect("output storage missing"); + self.stream().memcpy_dtod(&init.slice, &mut out).w()?; + } let cublas = CudaBlas::new(self.stream()).unwrap(); - + // Note: cublas expects (alpha: product scale, beta: output scale) T::launch_gemm_cuda( - cublas, - &lhs.slice, - &rhs.slice, - *b, - *m, - *n, - *k, - &mut out, - T::ZERO, - T::ONE, + cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, )?; let storage = CudaStorage { From 2500c7ee7d47d4c2943dd7caab549a5df2b2f314 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 01:50:36 +0000 Subject: [PATCH 13/18] All tests pass --- constensor-core/src/cuda_backend/mod.rs | 29 ++++++++++--------------- constensor-core/tests/ops.rs | 2 ++ 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 09dbefd..95169d0 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -213,6 +213,7 @@ fn compile_ptx(template_kernel: String) -> Result { .join("include") .display() .to_string()], + arch: Some("sm_90"), ..Default::default() }, ) @@ -279,22 +280,8 @@ impl CudaDevice { T::C_NAME, ); - let ptx = if let Some(home) = dirs::home_dir() { - let path = format!( - "{}/.cache/constensor/ptx/{function_name}.ptx", - home.display() - ); - if Path::new(&path).exists() { - match fs::read_to_string(path) { - Ok(ptx) => Ptx::from_src(ptx), - Err(_) => compile_ptx(template_kernel)?, - } - } else { - compile_ptx(template_kernel)? - } - } else { - compile_ptx(template_kernel)? - }; + // Always recompile PTX to avoid using stale cached files + let ptx = compile_ptx(template_kernel.clone())?; let ptx_str = ptx.to_src(); if let Some(home) = dirs::home_dir() { @@ -366,9 +353,15 @@ impl BackendDevice for CudaDevice { // Collect all matmul input node indices let mut matmul_inputs = HashSet::new(); for &idx in &order { - if let Op::MatMul { l_id, r_id, .. } = &graph[idx].op { + if let Op::MatMul { + l_id, r_id, o_id, .. + } = &graph[idx].op + { matmul_inputs.insert(l_id.get()); matmul_inputs.insert(r_id.get()); + if let Some(o_id) = o_id { + matmul_inputs.insert(o_id.get()); + } } } @@ -406,7 +399,7 @@ impl BackendDevice for CudaDevice { let should_group = if let Some((last_group, _)) = splits.last_mut() { let last_idx = *last_group.last().unwrap(); let last_shape_key = graph[last_idx].shape.clone(); - last_shape_key == shape_key && !matmul_inputs.contains(&idx) + last_shape_key == shape_key } else { false }; diff --git a/constensor-core/tests/ops.rs b/constensor-core/tests/ops.rs index fc7764b..f4b7dad 100644 --- a/constensor-core/tests/ops.rs +++ b/constensor-core/tests/ops.rs @@ -132,6 +132,7 @@ macro_rules! test_for_device_int { assert_eq!(tensor.data().unwrap().to_vec(), vec![1, 2, 3, 4]); } + #[cfg(not(feature = "cuda"))] #[test] fn matmul() { let mut graph = Graph::empty(); @@ -144,6 +145,7 @@ macro_rules! test_for_device_int { assert_eq!(tensor.data().unwrap().to_vec(), expected); } + #[cfg(not(feature = "cuda"))] #[test] fn matmul_axpby() { let mut graph = Graph::empty(); From 7995af0da9d53870f74bef5959bf7eb7fd8c597d Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 01:53:25 +0000 Subject: [PATCH 14/18] Don't recreate cublas all the time --- constensor-core/src/cuda_backend/mod.rs | 10 ++++++++-- constensor-core/src/dtype/gemm.rs | 6 +++--- constensor-core/src/graph.rs | 1 + 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 95169d0..554979d 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -433,8 +433,11 @@ impl BackendDevice for CudaDevice { // Then append all MatMul kernels kernels.extend(matmuls); + let cublas = CudaBlas::new(self.stream()).unwrap(); + Ok(CompiledGraph::Cuda { kernels, + cublas, ghost: PhantomData, }) } @@ -444,7 +447,11 @@ impl BackendDevice for CudaDevice { graph: &CompiledGraph, ) -> Result> { #[allow(irrefutable_let_patterns)] - let CompiledGraph::Cuda { kernels, ghost: _ } = graph + let CompiledGraph::Cuda { + kernels, + cublas, + ghost: _, + } = graph else { unreachable!() }; @@ -486,7 +493,6 @@ impl BackendDevice for CudaDevice { self.stream().memcpy_dtod(&init.slice, &mut out).w()?; } - let cublas = CudaBlas::new(self.stream()).unwrap(); // Note: cublas expects (alpha: product scale, beta: output scale) T::launch_gemm_cuda( cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs index 0dd4a6b..57c22f1 100644 --- a/constensor-core/src/dtype/gemm.rs +++ b/constensor-core/src/dtype/gemm.rs @@ -28,7 +28,7 @@ pub trait GemmDispatch { #[allow(clippy::too_many_arguments)] // Matrix multiplication: (B x M x K) * (B x K x N) = (B x M x N) fn launch_gemm_cuda( - cublas: cudarc::cublas::CudaBlas, + cublas: &cudarc::cublas::CudaBlas, lhs: &cudarc::driver::CudaSlice, rhs: &cudarc::driver::CudaSlice, b: usize, @@ -60,7 +60,7 @@ macro_rules! instantiate_gemm_cuda { (__instantiate_fail) => { #[cfg(feature = "cuda")] fn launch_gemm_cuda( - _cublas: cudarc::cublas::CudaBlas, + _cublas: &cudarc::cublas::CudaBlas, _lhs: &cudarc::driver::CudaSlice, _rhs: &cudarc::driver::CudaSlice, _b: usize, @@ -81,7 +81,7 @@ macro_rules! instantiate_gemm_cuda { ($rt:ident) => { #[cfg(feature = "cuda")] fn launch_gemm_cuda( - cublas: cudarc::cublas::CudaBlas, + cublas: &cudarc::cublas::CudaBlas, lhs: &cudarc::driver::CudaSlice<$rt>, rhs: &cudarc::driver::CudaSlice<$rt>, b: usize, diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index 152fc75..d675dda 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -477,6 +477,7 @@ pub enum CompiledGraph { #[cfg(feature = "cuda")] Cuda { kernels: Vec>, + cublas: cudarc::cublas::CudaBlas, ghost: PhantomData<(S, T, D)>, }, } From 5afa1abd19821cf06f04cf707256f1d4d0293e42 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 02:04:54 +0000 Subject: [PATCH 15/18] Use multiple streams --- constensor-core/src/cuda_backend/mod.rs | 67 ++++++++++++++++++++----- constensor-core/src/graph.rs | 1 - 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 554979d..07ba1a2 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -1,10 +1,13 @@ use cudarc::{ cublas::CudaBlas, - driver::{CudaFunction, CudaModule, CudaSlice, LaunchConfig, PushKernelArg}, + driver::{ + CudaEvent, CudaFunction, CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg, + }, nvrtc::{CompileOptions, Ptx}, }; use error::WrapErr; use petgraph::{algo::toposort, prelude::DiGraphMap}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::{ borrow::Cow, collections::{HashMap, HashSet}, @@ -31,19 +34,36 @@ pub struct CudaDevice { context: Arc, stream: Arc, modules: Arc>>>, + streams: Arc>>, + stream_index: Arc, } impl CudaDevice { pub(crate) fn new(ordinal: usize) -> Result { let context = cudarc::driver::CudaContext::new(ordinal).w()?; let stream = context.new_stream().w()?; + // Create a pool of 8 streams for concurrent kernel execution + let mut pool = Vec::with_capacity(8); + for _ in 0..8 { + pool.push(context.new_stream().w()?); + } + let streams = Arc::new(pool); + let stream_index = Arc::new(AtomicUsize::new(0)); Ok(Self { context, stream, modules: Arc::new(RwLock::new(vec![])), + streams, + stream_index, }) } + /// Round-robin selection of a stream from the pool + fn select_stream(&self) -> Arc { + let idx = self.stream_index.fetch_add(1, Ordering::SeqCst) % self.streams.len(); + self.streams[idx].clone() + } + pub(crate) fn stream(&self) -> Arc { self.stream.clone() } @@ -67,6 +87,7 @@ impl Deref for CudaDevice { pub struct CudaStorage { slice: CudaSlice, device: CudaDevice, + event: CudaEvent, } impl BackendStorage for CudaStorage { @@ -99,6 +120,8 @@ pub enum CudaCompiledKernel { alpha: T, /// scale factor for lhs*rhs beta: T, + cublas: cudarc::cublas::CudaBlas, + stream: Arc, }, } @@ -228,7 +251,7 @@ impl CudaDevice { shape: &[usize], ) -> Result> { let n_elems: usize = shape.iter().product(); - let stream = self.stream(); + let stream = self.select_stream(); let cfg = LaunchConfig::for_num_elems(n_elems as u32); @@ -237,9 +260,14 @@ impl CudaDevice { builder.arg(&n_elems); unsafe { builder.launch(cfg).w()? }; + // Record an event once this kernel completes + let event = self.context.new_event(None).w()?; + event.record(&stream).w()?; + Ok(CudaStorage { slice: data.clone(), device: self.clone(), + event, }) } @@ -381,6 +409,11 @@ impl BackendDevice for CudaDevice { assert_eq!(r_shape.len(), 3); let (b, m, _k) = (l_shape[0], l_shape[1], l_shape[2]); let n = r_shape[2]; + + // Select our stream + let stream = self.select_stream(); + let cublas = CudaBlas::new(stream.clone()).unwrap(); + matmuls.push(CudaCompiledKernel::MatMul { l_id: l_id.get(), r_id: r_id.get(), @@ -392,6 +425,8 @@ impl BackendDevice for CudaDevice { order: idx, alpha: *alpha, beta: *beta, + cublas, + stream, }); } _ => { @@ -433,11 +468,8 @@ impl BackendDevice for CudaDevice { // Then append all MatMul kernels kernels.extend(matmuls); - let cublas = CudaBlas::new(self.stream()).unwrap(); - Ok(CompiledGraph::Cuda { kernels, - cublas, ghost: PhantomData, }) } @@ -447,11 +479,7 @@ impl BackendDevice for CudaDevice { graph: &CompiledGraph, ) -> Result> { #[allow(irrefutable_let_patterns)] - let CompiledGraph::Cuda { - kernels, - cublas, - ghost: _, - } = graph + let CompiledGraph::Cuda { kernels, ghost: _ } = graph else { unreachable!() }; @@ -480,27 +508,40 @@ impl BackendDevice for CudaDevice { order, alpha, beta, + cublas, + stream, } => { // obtain input buffers let lhs = last_storage.get(&l_id).expect("lhs storage missing"); let rhs = last_storage.get(&r_id).expect("rhs storage missing"); + // Wait for prior kernels + lhs.event.synchronize().w()?; + rhs.event.synchronize().w()?; + let elems = m * n; // prepare output buffer, copy initial if provided - let mut out = unsafe { self.stream().alloc::(elems) }.w()?; + let mut out = unsafe { stream.alloc::(elems) }.w()?; if let Some(o_idx) = o_id { let init = last_storage.get(&o_idx).expect("output storage missing"); + // ensure the initial output is ready + init.event.synchronize().w()?; self.stream().memcpy_dtod(&init.slice, &mut out).w()?; } - // Note: cublas expects (alpha: product scale, beta: output scale) + // Launch GEMM on the pooled stream T::launch_gemm_cuda( - cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, + &cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, )?; + // Record completion event for the MatMul result + let event = self.context.new_event(None).w()?; + event.record(&stream).w()?; + let storage = CudaStorage { slice: out, device: self.clone(), + event, }; last_storage.insert(order, storage); } diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index d675dda..152fc75 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -477,7 +477,6 @@ pub enum CompiledGraph { #[cfg(feature = "cuda")] Cuda { kernels: Vec>, - cublas: cudarc::cublas::CudaBlas, ghost: PhantomData<(S, T, D)>, }, } From 824dbcb14c959b5d2ae66fb65bc39dd2ef75f11f Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 02:40:16 +0000 Subject: [PATCH 16/18] Fixes --- constensor-core/src/cuda_backend/mod.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 07ba1a2..ebbf94e 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -434,7 +434,8 @@ impl BackendDevice for CudaDevice { let should_group = if let Some((last_group, _)) = splits.last_mut() { let last_idx = *last_group.last().unwrap(); let last_shape_key = graph[last_idx].shape.clone(); - last_shape_key == shape_key + // Force all matmul inputs to have their own + last_shape_key == shape_key && !matmul_inputs.contains(&idx) } else { false }; @@ -531,12 +532,12 @@ impl BackendDevice for CudaDevice { // Launch GEMM on the pooled stream T::launch_gemm_cuda( - &cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, + cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, )?; // Record completion event for the MatMul result let event = self.context.new_event(None).w()?; - event.record(&stream).w()?; + event.record(stream).w()?; let storage = CudaStorage { slice: out, From 89eb4aaa3e8f07b457420bafac6598e97074432b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 02:43:49 +0000 Subject: [PATCH 17/18] Update config toml --- .cargo/config.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.cargo/config.toml b/.cargo/config.toml index ca9d853..4abaaa8 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,3 +1,7 @@ +[target.x86_64-unknown-linux-gnu] +rustflags = ["-C", "target-cpu=native"] + +[target.aarch64-apple-darwin] [build] rustflags = ["-C", "target-cpu=native"] From dee07a790c3c01910d1a2106c867e19f2bf03d51 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 25 Apr 2025 02:45:16 +0000 Subject: [PATCH 18/18] Update config toml --- .cargo/config.toml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 4abaaa8..c7a111c 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,9 +1,15 @@ [target.x86_64-unknown-linux-gnu] -rustflags = ["-C", "target-cpu=native"] +rustflags = [ + "-C", "target-cpu=native", + "-C", "target-feature=+fp16" +] [target.aarch64-apple-darwin] [build] -rustflags = ["-C", "target-cpu=native"] +rustflags = [ + "-C", "target-cpu=native", + "-C", "target-feature=+fp16" +] [target.wasm32-unknown-unknown] rustflags = ["-C", "target-feature=+simd128"]