diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..c7a111c --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,18 @@ +[target.x86_64-unknown-linux-gnu] +rustflags = [ + "-C", "target-cpu=native", + "-C", "target-feature=+fp16" +] + +[target.aarch64-apple-darwin] +[build] +rustflags = [ + "-C", "target-cpu=native", + "-C", "target-feature=+fp16" +] + +[target.wasm32-unknown-unknown] +rustflags = ["-C", "target-feature=+simd128"] + +[target.x86_64-apple-darwin] +rustflags = ["-C", "target-feature=-avx,-avx2"] \ No newline at end of file diff --git a/constensor-core/examples/hello_world/main.rs b/constensor-core/examples/hello_world/main.rs index f167c8a..372476e 100644 --- a/constensor-core/examples/hello_world/main.rs +++ b/constensor-core/examples/hello_world/main.rs @@ -2,20 +2,22 @@ use constensor_core::{Cpu, Graph, GraphTensor, Tensor, R1, R2}; fn main() { let mut graph: Graph = Graph::empty(); - let arange = GraphTensor::, f32, Cpu>::arange(&mut graph, 0., 1.); - dbg!(&arange.to_tensor().unwrap().data()); + let _arange = GraphTensor::, f32, Cpu>::arange(&mut graph, 0., 1.); let a = GraphTensor::, f32, Cpu>::fill(&mut graph, 1.0); let b = GraphTensor::, f32, Cpu>::fill(&mut graph, 2.0); let c = GraphTensor::, f32, Cpu>::fill(&mut graph, 3.0); let d = GraphTensor::, f32, Cpu>::fill(&mut graph, 4.0); let res = a * b + c; - let res = res + d; + let _out = res + d; graph.optimize(); graph.visualize("graph.png").unwrap(); - let tensor: Tensor, f32, Cpu> = res.to_tensor().unwrap(); + let compiled: constensor_core::CompiledGraph, f32, Cpu> = graph.compile().unwrap(); + let res = compiled.run().unwrap(); + + let tensor: Tensor, f32, Cpu> = res; assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![9.0; 4]; 3],); } diff --git a/constensor-core/examples/matmul/main.rs b/constensor-core/examples/matmul/main.rs index ce2897f..8dceab2 100644 --- a/constensor-core/examples/matmul/main.rs +++ b/constensor-core/examples/matmul/main.rs @@ -1,4 +1,4 @@ -use constensor_core::{Cpu, DType, Graph, GraphTensor, R3}; +use constensor_core::{CompiledGraph, Cpu, DType, Graph, GraphTensor, R3}; use std::time::Instant; fn bench( @@ -10,18 +10,19 @@ fn bench, T, Cpu>::ones(&mut graph); + let b = GraphTensor::, T, Cpu>::ones(&mut graph); + let o = GraphTensor::, T, Cpu>::ones(&mut graph); + let _c = a.matmul_axpby(b, o, alpha, beta); - let mut graph = Graph::empty(); - let a = GraphTensor::, T, Cpu>::ones(&mut graph); - let b = GraphTensor::, T, Cpu>::ones(&mut graph); - let o = GraphTensor::, T, Cpu>::ones(&mut graph); - let c = a.matmul_axpby(b, o, alpha, beta); + graph.optimize(); + let compiled: CompiledGraph, T, Cpu> = graph.compile().unwrap(); - graph.optimize(); + for _ in 0..iterations { + let start = Instant::now(); - let _tensor = std::hint::black_box(c.to_tensor().unwrap()); + let _tensor = std::hint::black_box(compiled.run().unwrap()); total += start.elapsed(); } diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index 3dd73c1..9cdeb7d 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -1,15 +1,17 @@ use petgraph::algo::toposort; use petgraph::graphmap::DiGraphMap; -use std::borrow::Cow; use std::cell::RefCell; use std::rc::Rc; +use std::{borrow::Cow, marker::PhantomData}; use pool::{BufferPool, PooledBuffer}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use crate::device::Dev; +use crate::Shape; use crate::{ storage::{BackendDevice, BackendStorage}, - DType, GraphNode, Op, Result, + CompiledGraph, DType, GraphNode, Op, Result, }; mod pool; @@ -29,49 +31,73 @@ impl BackendStorage for CpuStorage { impl BackendDevice for CpuDevice { type Storage = CpuStorage; - fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { - { - // Create a shared buffer pool - let pool = Rc::new(RefCell::new(BufferPool::::new())); - - // Build a dependency graph of tensor indices - let mut dep_graph = DiGraphMap::::new(); - for idx in 0..graph.len() { - dep_graph.add_node(idx); - } + fn compile( + &self, + graph: Vec>, + ) -> Result> { + // Build a dependency graph of tensor indices + let mut dep_graph = DiGraphMap::::new(); + for idx in 0..graph.len() { + dep_graph.add_node(idx); + } - for (idx, node) in graph.iter().enumerate() { - match &node.op { - Op::BinaryOp { l_id, r_id, .. } => { - dep_graph.add_edge(l_id.get(), idx, ()); - dep_graph.add_edge(r_id.get(), idx, ()); - } - Op::UnaryOp { v_id, .. } => { - dep_graph.add_edge(v_id.get(), idx, ()); - } - Op::FusedMulAdd { a_id, b_id, c_id } => { - dep_graph.add_edge(a_id.get(), idx, ()); - dep_graph.add_edge(b_id.get(), idx, ()); - dep_graph.add_edge(c_id.get(), idx, ()); - } - Op::MatMul { l_id, r_id, .. } => { - dep_graph.add_edge(l_id.get(), idx, ()); - dep_graph.add_edge(r_id.get(), idx, ()); - } - // NoOp and Fill/Arange don’t create incoming edges - Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} + for (idx, node) in graph.iter().enumerate() { + match &node.op { + Op::BinaryOp { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); } + Op::UnaryOp { v_id, .. } => { + dep_graph.add_edge(v_id.get(), idx, ()); + } + Op::FusedMulAdd { a_id, b_id, c_id } => { + dep_graph.add_edge(a_id.get(), idx, ()); + dep_graph.add_edge(b_id.get(), idx, ()); + dep_graph.add_edge(c_id.get(), idx, ()); + } + Op::MatMul { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); + } + // NoOp and Fill/Arange don’t create incoming edges + Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} } + } + + // Compute topological order + let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); + + Ok(CompiledGraph::Cpu { + order, + graph, + ghost: PhantomData, + }) + } + + fn run_graph( + &self, + graph: &CompiledGraph, + ) -> Result> { + { + // Create a shared buffer pool + let pool = Rc::new(RefCell::new(BufferPool::::new())); - // Compute topological order - let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); + #[allow(irrefutable_let_patterns)] + let CompiledGraph::Cpu { + order, + graph, + ghost: _, + } = graph + else { + unreachable!() + }; // Prepare storage for intermediate results let mut results: Vec>> = Vec::with_capacity(graph.len()); results.resize_with(graph.len(), || None); // Evaluate nodes in topological order - for idx in order { + for idx in order.clone() { let op = &graph[idx]; let out_shape = &op.shape; diff --git a/constensor-core/src/cuda_backend/error.rs b/constensor-core/src/cuda_backend/error.rs index 6539d64..7d85f21 100644 --- a/constensor-core/src/cuda_backend/error.rs +++ b/constensor-core/src/cuda_backend/error.rs @@ -6,6 +6,12 @@ pub enum CudaError { #[error(transparent)] Cuda(#[from] cudarc::driver::DriverError), + #[error(transparent)] + Compiler(#[from] cudarc::nvrtc::CompileError), + + #[error(transparent)] + Cublas(#[from] cudarc::cublas::result::CublasError), + #[error("{cuda} when loading {module_name}")] Load { cuda: cudarc::driver::DriverError, @@ -34,11 +40,3 @@ impl> WrapErr for std::result::Result { self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt()) } } - -impl WrapErr for std::result::Result { - fn w(self) -> std::result::Result { - self.map_err(|e| { - crate::Error::Cuda(Box::new(CudaError::PtxCompileError { err: e }).into()).bt() - }) - } -} diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index 47561a2..ebbf94e 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -1,53 +1,78 @@ +use cudarc::{ + cublas::CudaBlas, + driver::{ + CudaEvent, CudaFunction, CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg, + }, + nvrtc::{CompileOptions, Ptx}, +}; +use error::WrapErr; +use petgraph::{algo::toposort, prelude::DiGraphMap}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::{ borrow::Cow, - cell::OnceCell, + collections::{HashMap, HashSet}, fs, hash::{DefaultHasher, Hash, Hasher}, + marker::PhantomData, ops::Deref, path::{Path, PathBuf}, - sync::Arc, -}; -mod error; -use cudarc::{ - driver::{CudaFunction, CudaModule, CudaSlice, LaunchConfig, PushKernelArg}, - nvrtc::{CompileOptions, Ptx}, + sync::{Arc, RwLock}, }; -use error::WrapErr; use crate::{ cpu_storage::CpuStorage, - graph::GraphTensorId, + device::Dev, storage::{BackendDevice, BackendStorage}, - DType, Op, Result, SignedDType, + CompiledGraph, DType, GraphNode, Op, Result, Shape, }; +pub(crate) mod error; +pub(crate) mod util; + #[derive(Clone)] pub struct CudaDevice { context: Arc, stream: Arc, - module: OnceCell>, + modules: Arc>>>, + streams: Arc>>, + stream_index: Arc, } impl CudaDevice { pub(crate) fn new(ordinal: usize) -> Result { let context = cudarc::driver::CudaContext::new(ordinal).w()?; let stream = context.new_stream().w()?; + // Create a pool of 8 streams for concurrent kernel execution + let mut pool = Vec::with_capacity(8); + for _ in 0..8 { + pool.push(context.new_stream().w()?); + } + let streams = Arc::new(pool); + let stream_index = Arc::new(AtomicUsize::new(0)); Ok(Self { context, stream, - module: OnceCell::new(), + modules: Arc::new(RwLock::new(vec![])), + streams, + stream_index, }) } + /// Round-robin selection of a stream from the pool + fn select_stream(&self) -> Arc { + let idx = self.stream_index.fetch_add(1, Ordering::SeqCst) % self.streams.len(); + self.streams[idx].clone() + } + pub(crate) fn stream(&self) -> Arc { self.stream.clone() } - pub(crate) fn get_or_load_func(&self, function_name: &str, ptx: Ptx) -> Result { - let module = self - .module - .get_or_init(|| self.context.load_module(ptx).w().unwrap()); - module.load_function(function_name).w() + pub(crate) fn load_func(&self, function_name: &str, ptx: Ptx) -> Result { + let module = self.context.load_module(ptx).w()?; + let func = module.load_function(function_name).w()?; + self.modules.write().unwrap().push(module); + Ok(func) } } @@ -62,6 +87,7 @@ impl Deref for CudaDevice { pub struct CudaStorage { slice: CudaSlice, device: CudaDevice, + event: CudaEvent, } impl BackendStorage for CudaStorage { @@ -71,6 +97,34 @@ impl BackendStorage for CudaStorage { } } +pub enum CudaCompiledKernel { + /// JIT‑compiled element‑wise kernel produced by `compile_kernel`. + ElementWise { + func: CudaFunction, + slice: CudaSlice, + shape: Vec, + order: usize, + }, + /// Matrix–multiplication kernel to be executed through cuBLAS. + MatMul { + l_id: usize, + r_id: usize, + /// Optional output tensor ID for axpby semantics + o_id: Option, + b: usize, + m: usize, + n: usize, + k: usize, + order: usize, + /// scale factor for existing output + alpha: T, + /// scale factor for lhs*rhs + beta: T, + cublas: cudarc::cublas::CudaBlas, + stream: Arc, + }, +} + #[derive(Debug)] struct Name(usize); impl Name { @@ -83,36 +137,21 @@ impl Name { fn handle_node( current_name: &mut usize, header: &mut String, - op: &Op, - graph: &[Op], + op: &GraphNode, + graph: &[GraphNode], ) -> String { - match op { + match &op.op { Op::BinaryOp { l_id, r_id, operator, } => { - let l_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(l_id)], - graph, - ); - let r_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(r_id)], - graph, - ); + let l_name = handle_node(current_name, header, &graph[l_id.get()], graph); + let r_name = handle_node(current_name, header, &graph[r_id.get()], graph); format!("({l_name} {} {r_name})", operator.as_c_op()) } Op::UnaryOp { v_id, operator } => { - let v_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(v_id)], - graph, - ); + let v_name = handle_node(current_name, header, &graph[v_id.get()], graph); operator.fill_in_c_op(v_name) } Op::Fill { v } => { @@ -121,8 +160,11 @@ fn handle_node( *header += &format!("T {} = {v:?};\n", name.to_name()); format!("({})", name.to_name()) } - Op::Arange { start, step, stop } => { - compile_error!("arange is not implemented for CUDA yet."); + Op::Arange { + start, + step, + stop: _, + } => { *current_name += 1; let name = Name(*current_name); *header += &format!( @@ -132,24 +174,9 @@ fn handle_node( format!("({})", name.to_name()) } Op::FusedMulAdd { a_id, b_id, c_id } => { - let a_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(a_id)], - graph, - ); - let b_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(b_id)], - graph, - ); - let c_name = handle_node( - current_name, - header, - &graph[<&GraphTensorId as Into>::into(c_id)], - graph, - ); + let a_name = handle_node(current_name, header, &graph[a_id.get()], graph); + let b_name = handle_node(current_name, header, &graph[b_id.get()], graph); + let c_name = handle_node(current_name, header, &graph[c_id.get()], graph); #[cfg(feature = "slow_integral_fma_cuda")] if T::INTEGRAL { use crate::graph::BinaryOpType; @@ -163,6 +190,7 @@ fn handle_node( format!("( static_cast(fma(static_cast({a_name}), static_cast({b_name}), static_cast({c_name}))))") } Op::NoOp => unreachable!("no-op ops should never be reached."), + Op::MatMul { .. } => unreachable!("matmul op should have its own split!"), } } @@ -208,6 +236,7 @@ fn compile_ptx(template_kernel: String) -> Result { .join("include") .display() .to_string()], + arch: Some("sm_90"), ..Default::default() }, ) @@ -215,11 +244,39 @@ fn compile_ptx(template_kernel: String) -> Result { } impl CudaDevice { - fn run_graph( + fn run_kernel( + &self, + func: &CudaFunction, + data: &CudaSlice, + shape: &[usize], + ) -> Result> { + let n_elems: usize = shape.iter().product(); + let stream = self.select_stream(); + + let cfg = LaunchConfig::for_num_elems(n_elems as u32); + + let mut builder = stream.launch_builder(func); + builder.arg(data); + builder.arg(&n_elems); + unsafe { builder.launch(cfg).w()? }; + + // Record an event once this kernel completes + let event = self.context.new_event(None).w()?; + event.record(&stream).w()?; + + Ok(CudaStorage { + slice: data.clone(), + device: self.clone(), + event, + }) + } + + fn compile_kernel( &self, header: String, body: String, - ) -> Result> { + shape: Vec, + ) -> Result<(CudaFunction, CudaSlice)> { // Module name is based on hash of body and header let mut hasher = DefaultHasher::new(); body.hash(&mut hasher); @@ -251,22 +308,8 @@ impl CudaDevice { T::C_NAME, ); - let ptx = if let Some(home) = dirs::home_dir() { - let path = format!( - "{}/.cache/constensor/ptx/{function_name}.ptx", - home.display() - ); - if Path::new(&path).exists() { - match fs::read_to_string(path) { - Ok(ptx) => Ptx::from_src(ptx), - Err(_) => compile_ptx(template_kernel)?, - } - } else { - compile_ptx(template_kernel)? - } - } else { - compile_ptx(template_kernel)? - }; + // Always recompile PTX to avoid using stale cached files + let ptx = compile_ptx(template_kernel.clone())?; let ptx_str = ptx.to_src(); if let Some(home) = dirs::home_dir() { @@ -281,36 +324,232 @@ impl CudaDevice { fs::write(path, ptx_str)?; } - let n_elems = S::element_count(); + let n_elems = shape.iter().product(); let stream = self.stream(); let data = unsafe { stream.alloc::(n_elems) }.w()?; - let func = self.get_or_load_func(&function_name, ptx)?; - - let cfg = LaunchConfig::for_num_elems(n_elems as u32); - - let mut builder = stream.launch_builder(&func); - builder.arg(&data); - builder.arg(&n_elems); - unsafe { builder.launch(cfg).w()? }; + let func = self.load_func(&function_name, ptx)?; - Ok(CudaStorage { - slice: data, - device: self.clone(), - }) + Ok((func, data)) } } impl BackendDevice for CudaDevice { type Storage = CudaStorage; - fn compile_and_run_graph( + fn compile( + &self, + graph: Vec>, + ) -> Result> { + // Build a dependency graph of tensor indices + let mut dep_graph = DiGraphMap::::new(); + for idx in 0..graph.len() { + dep_graph.add_node(idx); + } + + for (idx, node) in graph.iter().enumerate() { + match &node.op { + Op::BinaryOp { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); + } + Op::UnaryOp { v_id, .. } => { + dep_graph.add_edge(v_id.get(), idx, ()); + } + Op::FusedMulAdd { a_id, b_id, c_id } => { + dep_graph.add_edge(a_id.get(), idx, ()); + dep_graph.add_edge(b_id.get(), idx, ()); + dep_graph.add_edge(c_id.get(), idx, ()); + } + Op::MatMul { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); + } + // NoOp and Fill/Arange don’t create incoming edges + Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} + } + } + + // Compute topological order + let order = toposort(&dep_graph, None).expect("Cycle detected in graph!"); + + // New kernel and grouping logic with matmul input tracking + let mut kernels = Vec::>::new(); + let mut matmuls = Vec::>::new(); + let mut splits: Vec<(Vec, Vec)> = Vec::new(); + // Collect all matmul input node indices + let mut matmul_inputs = HashSet::new(); + for &idx in &order { + if let Op::MatMul { + l_id, r_id, o_id, .. + } = &graph[idx].op + { + matmul_inputs.insert(l_id.get()); + matmul_inputs.insert(r_id.get()); + if let Some(o_id) = o_id { + matmul_inputs.insert(o_id.get()); + } + } + } + + for &idx in &order { + match &graph[idx].op { + Op::MatMul { + l_id, + r_id, + o_id, + k, + alpha, + beta, + } => { + let l_shape = &graph[l_id.get()].shape; + let r_shape = &graph[r_id.get()].shape; + assert_eq!(l_shape.len(), 3); + assert_eq!(r_shape.len(), 3); + let (b, m, _k) = (l_shape[0], l_shape[1], l_shape[2]); + let n = r_shape[2]; + + // Select our stream + let stream = self.select_stream(); + let cublas = CudaBlas::new(stream.clone()).unwrap(); + + matmuls.push(CudaCompiledKernel::MatMul { + l_id: l_id.get(), + r_id: r_id.get(), + o_id: o_id.as_ref().map(|id| id.get()), + b, + m, + n, + k: *k, + order: idx, + alpha: *alpha, + beta: *beta, + cublas, + stream, + }); + } + _ => { + let shape_key = graph[idx].shape.clone(); + let should_group = if let Some((last_group, _)) = splits.last_mut() { + let last_idx = *last_group.last().unwrap(); + let last_shape_key = graph[last_idx].shape.clone(); + // Force all matmul inputs to have their own + last_shape_key == shape_key && !matmul_inputs.contains(&idx) + } else { + false + }; + if should_group { + splits.last_mut().unwrap().0.push(idx); + } else { + splits.push((vec![idx], shape_key)); + } + } + } + } + + // Compile element‑wise splits first so matmul inputs are ready + for (sub_order, shape) in splits { + let mut header = String::new(); + let body = handle_node( + &mut 0, + &mut header, + &graph[*sub_order.last().unwrap()], + &graph, + ); + let (func, slice) = + self.compile_kernel::(header.clone(), body.clone(), shape.clone())?; + kernels.push(CudaCompiledKernel::ElementWise { + func, + slice, + shape, + order: *sub_order.iter().max().unwrap(), + }); + } + // Then append all MatMul kernels + kernels.extend(matmuls); + + Ok(CompiledGraph::Cuda { + kernels, + ghost: PhantomData, + }) + } + + fn run_graph( &self, - nodes: &[crate::Op], + graph: &CompiledGraph, ) -> Result> { - let mut header = "".to_string(); - let body = handle_node(&mut 0, &mut header, nodes.last().unwrap(), nodes); - self.run_graph::(header, body) + #[allow(irrefutable_let_patterns)] + let CompiledGraph::Cuda { kernels, ghost: _ } = graph + else { + unreachable!() + }; + + // For each group of nodes with matching input shapes/dtype, generate and run kernels + let mut last_storage = HashMap::new(); + for kernel in kernels { + match kernel { + CudaCompiledKernel::ElementWise { + func, + slice, + shape, + order, + } => { + let storage = self.run_kernel::(func, slice, shape)?; + last_storage.insert(order, storage); + } + CudaCompiledKernel::MatMul { + l_id, + r_id, + o_id, + b, + m, + n, + k, + order, + alpha, + beta, + cublas, + stream, + } => { + // obtain input buffers + let lhs = last_storage.get(&l_id).expect("lhs storage missing"); + let rhs = last_storage.get(&r_id).expect("rhs storage missing"); + + // Wait for prior kernels + lhs.event.synchronize().w()?; + rhs.event.synchronize().w()?; + + let elems = m * n; + // prepare output buffer, copy initial if provided + let mut out = unsafe { stream.alloc::(elems) }.w()?; + if let Some(o_idx) = o_id { + let init = last_storage.get(&o_idx).expect("output storage missing"); + // ensure the initial output is ready + init.event.synchronize().w()?; + self.stream().memcpy_dtod(&init.slice, &mut out).w()?; + } + + // Launch GEMM on the pooled stream + T::launch_gemm_cuda( + cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha, + )?; + + // Record completion event for the MatMul result + let event = self.context.new_event(None).w()?; + event.record(stream).w()?; + + let storage = CudaStorage { + slice: out, + device: self.clone(), + event, + }; + last_storage.insert(order, storage); + } + } + } + + let key = *last_storage.keys().max().unwrap(); + Ok(last_storage.remove(&key).unwrap()) } } diff --git a/constensor-core/src/cuda_backend/util.rs b/constensor-core/src/cuda_backend/util.rs new file mode 100644 index 0000000..9ac05ff --- /dev/null +++ b/constensor-core/src/cuda_backend/util.rs @@ -0,0 +1,97 @@ +use cudarc::cublas::{GemmConfig, StridedBatchedConfig}; + +use crate::{Error, Result}; + +pub(crate) fn gemm_config( + alpha: T, + beta: T, + (b, m, n, k): (usize, usize, usize, usize), +) -> Result> { + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm + use cudarc::cublas::sys::cublasOperation_t; + + let lhs_stride = [m * k, k, 1]; + let rhs_stride = [k * n, n, 1]; + let lhs_dims = [b, m, k]; + let rhs_dims = [b, k, n]; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + // The a tensor has dims batching, k, n (rhs) + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, cublasOperation_t::CUBLAS_OP_N) + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { + (k as i32, cublasOperation_t::CUBLAS_OP_T) + } else { + Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })? + }; + // The b tensor has dims batching, m, k (lhs) + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, cublasOperation_t::CUBLAS_OP_N) + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { + (m as i32, cublasOperation_t::CUBLAS_OP_T) + } else { + Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })? + }; + // The setup below was copied from: + // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531 + let gemm = GemmConfig { + alpha, + beta, + m: n as i32, + n: m as i32, + k: k as i32, + lda, + ldb, + ldc: n as i32, + transa, + transb, + }; + + let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] { + [s1, stride] if s1 == stride * lhs_dims[1] => stride, + [_, stride] if lhs_dims[0] == 1 => stride, + [stride, _] if lhs_dims[1] == 1 => stride, + [stride] => stride, + [] => m * k, + _ => Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })?, + }; + let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] { + [s1, stride] if s1 == stride * rhs_dims[1] => stride, + [_, stride] if rhs_dims[0] == 1 => stride, + [stride, _] if rhs_dims[1] == 1 => stride, + [stride] => stride, + [] => n * k, + _ => Err(Error::MatMulNonContiguous { + lhs_stride, + rhs_stride, + mnk: (m, n, k), + })?, + }; + + Ok(StridedBatchedConfig { + batch_size: b as i32, + gemm, + stride_a: stride_a as i64, + stride_b: stride_b as i64, + stride_c: (m * n) as i64, + }) +} diff --git a/constensor-core/src/device.rs b/constensor-core/src/device.rs index 4c181fd..7825235 100644 --- a/constensor-core/src/device.rs +++ b/constensor-core/src/device.rs @@ -3,7 +3,7 @@ use crate::cuda_backend::CudaDevice; use crate::{ cpu_storage::CpuDevice, storage::{BackendDevice, Storage}, - DType, GraphNode, Result, + CompiledGraph, DType, GraphNode, Result, Shape, }; /// Marker trait for devices @@ -66,11 +66,25 @@ pub enum Device { } impl Device { - pub fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { + pub fn run_graph( + &self, + graph: &CompiledGraph, + ) -> Result> { match self { #[cfg(feature = "cuda")] - Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.compile_and_run_graph::(graph)?)), - Self::Cpu => Ok(Storage::Cpu(CpuDevice.compile_and_run_graph::(graph)?)), + Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.run_graph::(graph)?)), + Self::Cpu => Ok(Storage::Cpu(CpuDevice.run_graph::(graph)?)), + } + } + + pub fn compile( + &self, + graph: Vec>, + ) -> Result> { + match self { + #[cfg(feature = "cuda")] + Self::Cuda(cuda) => cuda.compile::(graph), + Self::Cpu => CpuDevice.compile::(graph), } } } diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs index 1bff148..57c22f1 100644 --- a/constensor-core/src/dtype/gemm.rs +++ b/constensor-core/src/dtype/gemm.rs @@ -23,6 +23,94 @@ pub trait GemmDispatch { beta: Self, ) where Self: Sized; + + #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] + // Matrix multiplication: (B x M x K) * (B x K x N) = (B x M x N) + fn launch_gemm_cuda( + cublas: &cudarc::cublas::CudaBlas, + lhs: &cudarc::driver::CudaSlice, + rhs: &cudarc::driver::CudaSlice, + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut cudarc::driver::CudaSlice, + alpha: Self, + beta: Self, + ) -> crate::Result<()> + where + Self: Sized; +} + +macro_rules! instantiate_gemm_cuda { + (u8) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + (u32) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + (i32) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + (i64) => { + instantiate_gemm_cuda!(__instantiate_fail); + }; + + (__instantiate_fail) => { + #[cfg(feature = "cuda")] + fn launch_gemm_cuda( + _cublas: &cudarc::cublas::CudaBlas, + _lhs: &cudarc::driver::CudaSlice, + _rhs: &cudarc::driver::CudaSlice, + _b: usize, + _m: usize, + _n: usize, + _k: usize, + _out: &mut cudarc::driver::CudaSlice, + _alpha: Self, + _beta: Self, + ) -> crate::Result<()> + where + Self: Sized, + { + panic!("`launch_gemm_cuda` called with invalid configuration (w/o CUDA, dtype)") + } + }; + + ($rt:ident) => { + #[cfg(feature = "cuda")] + fn launch_gemm_cuda( + cublas: &cudarc::cublas::CudaBlas, + lhs: &cudarc::driver::CudaSlice<$rt>, + rhs: &cudarc::driver::CudaSlice<$rt>, + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut cudarc::driver::CudaSlice<$rt>, + alpha: $rt, + beta: $rt, + ) -> crate::Result<()> { + use crate::cuda_backend::error::WrapErr; + use cudarc::cublas::Gemm; + + let gemm_cfg = crate::cuda_backend::util::gemm_config(alpha, beta, (b, m, n, k))?; + + unsafe { + cublas + .gemm_strided_batched( + gemm_cfg, + &lhs.as_view(), + &rhs.as_view(), + &mut out.as_view_mut(), + ) + .w()?; + } + + Ok(()) + } + }; } macro_rules! instantiate_gemm { @@ -54,6 +142,8 @@ macro_rules! instantiate_gemm { } } } + + instantiate_gemm_cuda!($rt); } }; @@ -121,6 +211,8 @@ macro_rules! instantiate_gemm { } } } + + instantiate_gemm_cuda!($rt); } }; // SIMD-accelerated gemm using SimdSupported for vectorized operations along 'n' dimension @@ -218,6 +310,8 @@ macro_rules! instantiate_gemm { } } } + + instantiate_gemm_cuda!($rt); } }; } diff --git a/constensor-core/src/dtype/mod.rs b/constensor-core/src/dtype/mod.rs index 7a39cc6..56bf6bd 100644 --- a/constensor-core/src/dtype/mod.rs +++ b/constensor-core/src/dtype/mod.rs @@ -143,7 +143,6 @@ maybe_neg!(i64); maybe_neg!(f32); maybe_neg!(f64); -#[cfg(not(feature = "cuda"))] /// Marker trait for tensor datatypes. pub trait DType: Debug + Clone + DTypeOps + Send + Sync + MaybeNeg + DeviceReprLike + 'static diff --git a/constensor-core/src/error.rs b/constensor-core/src/error.rs index 983affc..8ed3fab 100644 --- a/constensor-core/src/error.rs +++ b/constensor-core/src/error.rs @@ -1,3 +1,5 @@ +use std::{convert::Infallible, fmt::Display}; + #[derive(thiserror::Error, Debug)] pub enum Error { #[error(transparent)] @@ -14,11 +16,36 @@ pub enum Error { #[error("IO error: {0}")] IoError(String), + + /// Arbitrary errors wrapping. + #[error(transparent)] + Wrapped(Box), + + /// Arbitrary errors wrapping with context. + #[error("{wrapped:?}\n{context:?}")] + WrappedContext { + wrapped: Box, + context: String, + }, + + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] + MatMulNonContiguous { + lhs_stride: [usize; 3], + rhs_stride: [usize; 3], + mnk: (usize, usize, usize), + }, } pub type Result = std::result::Result; impl Error { + /// Create a new error based on a printable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. + pub fn msg(msg: M) -> Self { + Self::Msg(msg.to_string()).bt() + } + pub fn bt(self) -> Self { let backtrace = std::backtrace::Backtrace::capture(); match backtrace.status() { @@ -37,3 +64,99 @@ impl From for Error { Error::IoError(value.to_string()) } } + +#[macro_export] +macro_rules! bail { + ($msg:literal $(,)?) => { + return Err($crate::Error::Msg(format!($msg).into()).bt()) + }; + ($err:expr $(,)?) => { + return Err($crate::Error::Msg(format!($err).into()).bt()) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt()) + }; +} + +pub(crate) mod private { + pub trait Sealed {} + + impl Sealed for std::result::Result where E: std::error::Error {} + impl Sealed for Option {} +} + +/// Attach more context to an error. +/// +/// Inspired by [`anyhow::Context`]. +pub trait Context: private::Sealed { + /// Wrap the error value with additional context. + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context(self, f: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl Context for std::result::Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using map_err to save 2 useless frames off the captured backtrace + // in ext_context. + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context.to_string(), + }), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context().to_string(), + }), + } + } +} + +impl Context for Option { + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using ok_or_else to save 2 useless frames off the captured + // backtrace. + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context)), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context())), + } + } +} diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index f71f229..152fc75 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -5,13 +5,14 @@ use std::{ fmt::Display, fs, hash::Hash, + marker::PhantomData, path::Path, process::Command, rc::Rc, sync::{Arc, RwLock, RwLockReadGuard}, }; -use crate::{DType, Result, Shape}; +use crate::{device::Dev, tensor::concretetensor::from_storage, DType, Result, Shape, Tensor}; use petgraph::Graph as PetGraph; use petgraph::{dot::Dot, graph::NodeIndex}; @@ -441,6 +442,51 @@ impl Graph { self.optimize_inplace_fma(); self.optimize_inplace_matmul(); } + + pub fn compile(self) -> Result> { + if self + .data + .read() + .unwrap() + .last() + .is_some_and(|last| last.shape != S::shape()) + { + let read = self.data.read(); + let last = read.as_ref().unwrap().last().unwrap(); + + crate::bail!( + "Graph compiled shape is {:?} does not match the last node shape {:?}!", + &last.shape, + S::shape() + ); + } + + let device = D::resolve()?; + + device.compile(self.data.read().unwrap().clone()) + } +} + +/// A representation of the compiled graph. The shape is the output shape. +pub enum CompiledGraph { + Cpu { + order: Vec, + graph: Vec>, + ghost: PhantomData<(S, T, D)>, + }, + #[cfg(feature = "cuda")] + Cuda { + kernels: Vec>, + ghost: PhantomData<(S, T, D)>, + }, +} + +impl CompiledGraph { + pub fn run(&self) -> Result> { + let device = D::resolve()?; + let storage = device.run_graph(self)?; + Ok(from_storage(Arc::new(storage))) + } } #[derive(PartialEq, Debug, Clone, Copy)] diff --git a/constensor-core/src/lib.rs b/constensor-core/src/lib.rs index c8af30a..a614f3e 100644 --- a/constensor-core/src/lib.rs +++ b/constensor-core/src/lib.rs @@ -15,7 +15,7 @@ pub use device::Cpu; #[cfg(feature = "cuda")] pub use device::Cuda; pub use dtype::DType; -pub use error::{Error, Result}; -pub use graph::{Graph, GraphNode, Op}; +pub use error::{Context, Error, Result}; +pub use graph::{CompiledGraph, Graph, GraphNode, Op}; pub use shape::{Shape, R1, R2, R3, R4, R5, R6}; pub use tensor::{GraphTensor, Tensor}; diff --git a/constensor-core/src/storage.rs b/constensor-core/src/storage.rs index 2d10bf6..8394207 100644 --- a/constensor-core/src/storage.rs +++ b/constensor-core/src/storage.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; #[cfg(feature = "cuda")] use crate::cuda_backend::CudaStorage; -use crate::{cpu_storage::CpuStorage, DType, GraphNode, Result}; +use crate::{cpu_storage::CpuStorage, device::Dev, CompiledGraph, DType, GraphNode, Result, Shape}; pub enum Storage { #[cfg(feature = "cuda")] @@ -27,5 +27,12 @@ pub trait BackendStorage { pub trait BackendDevice { type Storage: BackendStorage; - fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result>; + fn compile( + &self, + graph: Vec>, + ) -> Result>; + fn run_graph( + &self, + graph: &CompiledGraph, + ) -> Result>; } diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs index 6269718..1872a4d 100644 --- a/constensor-core/src/tensor/graphtensor.rs +++ b/constensor-core/src/tensor/graphtensor.rs @@ -7,8 +7,7 @@ use std::{ use crate::{ device::Dev, graph::{BinaryOpType, Graph, GraphTensorId, Op, UnaryOpType}, - tensor::concretetensor::from_storage, - DType, Result, Shape, Tensor, R1, R3, + DType, Shape, R1, R3, }; /// A tensor representing an intermediary result of a graph. Performing operations @@ -126,16 +125,6 @@ impl GraphTensor { pub fn id(&self) -> GraphTensorId { self.id.clone() } - - /// Convert this `GraphTensor` into a concrete `Tensor`. - pub fn to_tensor(self) -> Result> { - let graph = self.graph.read().unwrap(); - let nodes = &*graph.get_ops(); - - let device = D::resolve()?; - let storage = device.compile_and_run_graph::(nodes)?; - Ok(from_storage(Arc::new(storage))) - } } impl GraphTensor, T, D> { diff --git a/constensor-core/tests/fma.rs b/constensor-core/tests/fma.rs index ed3a838..1bf9ca3 100644 --- a/constensor-core/tests/fma.rs +++ b/constensor-core/tests/fma.rs @@ -1,6 +1,6 @@ #[cfg(feature = "cuda")] use constensor_core::Cuda; -use constensor_core::{Cpu, Graph, GraphTensor, R2}; +use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R2}; macro_rules! test_for_device_fma { ($dev:ty, $name:ident) => { @@ -12,8 +12,9 @@ macro_rules! test_for_device_fma { let a = GraphTensor::, f32, $dev>::fill(&mut graph, 2.0); let b = GraphTensor::, f32, $dev>::fill(&mut graph, 3.0); let c = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![10.0; 4]; 3],); } @@ -23,8 +24,9 @@ macro_rules! test_for_device_fma { let a = GraphTensor::, i32, $dev>::fill(&mut graph, 2); let b = GraphTensor::, i32, $dev>::fill(&mut graph, 3); let c = GraphTensor::, i32, $dev>::fill(&mut graph, 4); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![10; 4]; 3],); } } @@ -47,8 +49,9 @@ macro_rules! test_for_device_half_fma { GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(3.0)); let c = GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(4.0)); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![f16::from_f64_const(10.0); 4]; 3], @@ -80,8 +83,9 @@ macro_rules! test_for_device_bfloat_fma { &mut graph, bf16::from_f64_const(4.0), ); - let res = a * b + c; - let tensor = res.to_tensor().unwrap(); + let _res = a * b + c; + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![bf16::from_f64_const(10.0); 4]; 3], diff --git a/constensor-core/tests/ops.rs b/constensor-core/tests/ops.rs index deba8ec..f4b7dad 100644 --- a/constensor-core/tests/ops.rs +++ b/constensor-core/tests/ops.rs @@ -1,6 +1,6 @@ #[cfg(feature = "cuda")] use constensor_core::Cuda; -use constensor_core::{Cpu, Graph, GraphTensor, R1, R2, R3}; +use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R1, R2, R3}; #[cfg(feature = "bfloat")] use half::bf16; #[cfg(feature = "half")] @@ -13,8 +13,9 @@ macro_rules! test_for_device_float { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = GraphTensor::, f32, $dev>::fill(&mut graph, 0.0); - let tensor = gt.to_tensor().unwrap(); + let _gt = GraphTensor::, f32, $dev>::fill(&mut graph, 0.0); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -32,8 +33,9 @@ macro_rules! test_for_device_float { let y = GraphTensor::, f32, $dev>::fill(&mut graph, 2.0); let z = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -49,8 +51,9 @@ macro_rules! test_for_device_float { let mut graph = Graph::empty(); let x = GraphTensor::, f32, $dev>::fill(&mut graph, 1.0); let y = GraphTensor::, f32, $dev>::arange(&mut graph, 0.0, 1.0); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![1.0, 1.25, 1.5, 1.75]); } @@ -59,8 +62,9 @@ macro_rules! test_for_device_float { let mut graph = Graph::empty(); let a = GraphTensor::, f32, $dev>::ones(&mut graph); let b = GraphTensor::, f32, $dev>::ones(&mut graph); - let c = a.matmul(b); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul(b); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[f32; 2]>; 1] = [vec![[3.0, 3.0], [3.0, 3.0]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } @@ -71,8 +75,9 @@ macro_rules! test_for_device_float { let a = GraphTensor::, f32, $dev>::ones(&mut graph); let b = GraphTensor::, f32, $dev>::ones(&mut graph); let o = GraphTensor::, f32, $dev>::ones(&mut graph); - let c = a.matmul_axpby(b, o, 1., 1.); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul_axpby(b, o, 1., 1.); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[f32; 2]>; 1] = [vec![[4.0, 4.0], [4.0, 4.0]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } @@ -91,8 +96,9 @@ macro_rules! test_for_device_int { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = GraphTensor::, i32, $dev>::fill(&mut graph, 0); - let tensor = gt.to_tensor().unwrap(); + let _gt = GraphTensor::, i32, $dev>::fill(&mut graph, 0); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![[0, 0, 0, 0,], [0, 0, 0, 0,], [0, 0, 0, 0,],], @@ -106,8 +112,9 @@ macro_rules! test_for_device_int { let y = GraphTensor::, i32, $dev>::fill(&mut graph, 2); let z = GraphTensor::, i32, $dev>::fill(&mut graph, 4); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![[1, 1, 1, 1,], [1, 1, 1, 1,], [1, 1, 1, 1,],], @@ -119,30 +126,35 @@ macro_rules! test_for_device_int { let mut graph = Graph::empty(); let x = GraphTensor::, i32, $dev>::fill(&mut graph, 1); let y = GraphTensor::, i32, $dev>::arange(&mut graph, 0, 4); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![1, 2, 3, 4]); } + #[cfg(not(feature = "cuda"))] #[test] fn matmul() { let mut graph = Graph::empty(); let a = GraphTensor::, i32, $dev>::ones(&mut graph); let b = GraphTensor::, i32, $dev>::ones(&mut graph); - let c = a.matmul(b); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul(b); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[i32; 2]>; 1] = [vec![[3, 3], [3, 3]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } + #[cfg(not(feature = "cuda"))] #[test] fn matmul_axpby() { let mut graph = Graph::empty(); let a = GraphTensor::, i32, $dev>::ones(&mut graph); let b = GraphTensor::, i32, $dev>::ones(&mut graph); let o = GraphTensor::, i32, $dev>::ones(&mut graph); - let c = a.matmul_axpby(b, o, 1, 1); - let tensor = c.to_tensor().unwrap(); + let _c = a.matmul_axpby(b, o, 1, 1); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); let expected: [Vec<[i32; 2]>; 1] = [vec![[4, 4], [4, 4]]]; assert_eq!(tensor.data().unwrap().to_vec(), expected); } @@ -162,9 +174,10 @@ macro_rules! test_for_device_half { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = + let _gt = GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(0.0)); - let tensor = gt.to_tensor().unwrap(); + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![f16::from_f64_const(0.0); 4]; 3], @@ -181,8 +194,9 @@ macro_rules! test_for_device_half { let z = GraphTensor::, f16, $dev>::fill(&mut graph, f16::from_f64_const(4.0)); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![f16::from_f64_const(1.3330078); 4]; 3], @@ -198,8 +212,9 @@ macro_rules! test_for_device_half { f16::from_f64_const(0.0), f16::from_f64_const(1.0), ); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, f16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -227,11 +242,12 @@ macro_rules! test_for_device_bfloat { #[test] fn fill() { let mut graph = Graph::empty(); - let gt = GraphTensor::, bf16, $dev>::fill( + let _gt = GraphTensor::, bf16, $dev>::fill( &mut graph, bf16::from_f64_const(0.0), ); - let tensor = gt.to_tensor().unwrap(); + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![bf16::from_f64_const(0.0); 4]; 3], @@ -254,8 +270,9 @@ macro_rules! test_for_device_bfloat { bf16::from_f64_const(4.0), ); let c = x + y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![vec![bf16::from_f64_const(1.3330078); 4]; 3], @@ -272,8 +289,9 @@ macro_rules! test_for_device_bfloat { bf16::from_f64_const(0.0), bf16::from_f64_const(1.0), ); - let res = x + y; - let tensor = res.to_tensor().unwrap(); + let _res = x + y; + let compiled: CompiledGraph, bf16, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!( tensor.data().unwrap().to_vec(), vec![ @@ -305,8 +323,9 @@ macro_rules! test_for_device_float_unary { let y = GraphTensor::, f32, $dev>::fill(&mut graph, 2.0); let z = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); let c = x + -y; - let res = z / c; - let tensor = res.to_tensor().unwrap(); + let _res = z / c; + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![-4.0; 4]; 3],); } } @@ -326,8 +345,9 @@ macro_rules! test_for_device_sqrt { fn sqrt_float() { let mut graph = Graph::empty(); let x = GraphTensor::, f32, $dev>::fill(&mut graph, 4.0); - let res = x.sqrt(); - let tensor = res.to_tensor().unwrap(); + let _res = x.sqrt(); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![2.0; 4]; 3],); } @@ -335,8 +355,9 @@ macro_rules! test_for_device_sqrt { fn sqrt_int() { let mut graph = Graph::empty(); let x = GraphTensor::, i32, $dev>::fill(&mut graph, 5); - let res = x.sqrt(); - let tensor = res.to_tensor().unwrap(); + let _res = x.sqrt(); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![2; 4]; 3],); } }