diff --git a/Cargo.lock b/Cargo.lock index 37c3b4e..ab9dffc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,32 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "bytemuck" +version = "1.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ecc273b49b3205b83d648f0690daa588925572cc5063745bfe547fe7ec8e1a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cfg-if" version = "1.0.0" @@ -32,7 +58,9 @@ version = "0.1.0" dependencies = [ "cudarc", "dirs", + "gemm", "half", + "num_cpus", "petgraph", "rayon", "thiserror", @@ -100,12 +128,33 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "dyn-stack" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490bd48eb68fffcfed519b4edbfd82c69cbe741d175b84f0e0cbe8c57cbe0bdd" +dependencies = [ + "bytemuck", +] + [[package]] name = "either" version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -124,6 +173,125 @@ version 
= "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "gemm" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "libm", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", + "sysctl", +] + +[[package]] +name = "gemm-f16" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = 
"0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -141,6 +309,7 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -159,6 +328,18 @@ dependencies = [ "foldhash", ] +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "indexmap" version = "2.9.0" @@ -187,9 +368,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.8" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "c9627da5196e5d8ed0b0495e61e518847578da83483c37288316d9b2e03a7f72" [[package]] name = "libredox" @@ -201,6 +382,16 @@ dependencies = [ "libc", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "bytemuck", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -211,12 +402,34 @@ dependencies = [ "libm", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + [[package]] name = "option-ext" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "petgraph" version = "0.8.1" @@ -238,6 +451,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "pulp" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95fb7a99b37aaef4c7dd2fd15a819eb8010bfc7a2c2155230d51f497316cad6d" +dependencies = [ + "bytemuck", + "cfg-if", + "libm", + "num-complex", + "reborrow", + "version_check", +] + [[package]] name = "quote" version = "1.0.36" @@ -272,6 +499,15 @@ dependencies = [ "rand", ] +[[package]] +name = "raw-cpuid" +version = "11.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" +dependencies = [ + "bitflags", +] + [[package]] name = "rayon" version = "1.10.0" @@ -292,6 +528,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "reborrow" +version = "0.5.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + [[package]] name = "redox_users" version = "0.4.5" @@ -303,6 +545,21 @@ dependencies = [ "thiserror", ] +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "seq-macro" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" + [[package]] name = "serde" version = "1.0.210" @@ -334,6 +591,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sysctl" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" +dependencies = [ + "bitflags", + "byteorder", + "enum-as-inner", + "libc", + "thiserror", + "walkdir", +] + [[package]] name = "thiserror" version = "1.0.61" @@ -360,12 +631,37 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] 
+name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/constensor-core/Cargo.toml b/constensor-core/Cargo.toml index ab62960..f4396a8 100644 --- a/constensor-core/Cargo.toml +++ b/constensor-core/Cargo.toml @@ -16,6 +16,8 @@ half = { workspace = true, optional = true } dirs = "5.0.1" rayon = "1.10.0" petgraph.workspace = true +gemm = "0.18" +num_cpus = "1.16.0" [features] default = ["half", "bfloat"] diff --git a/constensor-core/examples/hello_world/main.rs b/constensor-core/examples/hello_world/main.rs index 83be376..f167c8a 100644 --- a/constensor-core/examples/hello_world/main.rs +++ b/constensor-core/examples/hello_world/main.rs @@ -4,10 +4,12 @@ fn main() { let mut graph: Graph = Graph::empty(); let arange = GraphTensor::, f32, Cpu>::arange(&mut graph, 0., 1.); dbg!(&arange.to_tensor().unwrap().data()); - let x = GraphTensor::, f32, Cpu>::fill(&mut graph, 1.0); - let y = GraphTensor::, f32, Cpu>::fill(&mut graph, 2.0); - let z = GraphTensor::, f32, Cpu>::fill(&mut graph, 2.0); - let res = y * x + z; + let a = GraphTensor::, f32, Cpu>::fill(&mut graph, 1.0); + let b = GraphTensor::, f32, Cpu>::fill(&mut graph, 2.0); + let c = GraphTensor::, f32, Cpu>::fill(&mut graph, 3.0); + let d = GraphTensor::, f32, Cpu>::fill(&mut graph, 4.0); + let res = a * b + c; + let res = res + d; graph.optimize(); @@ -15,5 +17,5 @@ fn main() { let tensor: Tensor, f32, Cpu> = res.to_tensor().unwrap(); - assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![4.0; 4]; 3],); + assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![9.0; 4]; 3],); } diff --git a/constensor-core/examples/matmul/main.rs b/constensor-core/examples/matmul/main.rs new file mode 100644 index 0000000..ce2897f --- /dev/null +++ b/constensor-core/examples/matmul/main.rs @@ -0,0 +1,41 
@@ +use constensor_core::{Cpu, DType, Graph, GraphTensor, R3}; +use std::time::Instant; + +fn bench( + type_name: &str, + alpha: T, + beta: T, +) { + // Number of times to run the matmul for averaging + let iterations = 1000; + let mut total = std::time::Duration::new(0, 0); + + for _ in 0..iterations { + let start = Instant::now(); + + let mut graph = Graph::empty(); + let a = GraphTensor::, T, Cpu>::ones(&mut graph); + let b = GraphTensor::, T, Cpu>::ones(&mut graph); + let o = GraphTensor::, T, Cpu>::ones(&mut graph); + let c = a.matmul_axpby(b, o, alpha, beta); + + graph.optimize(); + + let _tensor = std::hint::black_box(c.to_tensor().unwrap()); + + total += start.elapsed(); + } + + let avg = total / (iterations as u32); + println!("Average execution time for {type_name} over {iterations} iterations: {avg:?}"); +} + +fn main() { + const B: usize = 1; + const M: usize = 128; + const N: usize = 128; + const K: usize = 128; + + bench::("f32", 1.0, 1.0); + bench::("i32", 1, 1); +} diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index 0d0cad8..3dd73c1 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -8,9 +8,8 @@ use pool::{BufferPool, PooledBuffer}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; use crate::{ - graph::GraphTensorId, storage::{BackendDevice, BackendStorage}, - DType, Op, Result, Shape, + DType, GraphNode, Op, Result, }; mod pool; @@ -30,10 +29,7 @@ impl BackendStorage for CpuStorage { impl BackendDevice for CpuDevice { type Storage = CpuStorage; - fn compile_and_run_graph( - &self, - graph: &[Op], - ) -> Result> { + fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { { // Create a shared buffer pool let pool = Rc::new(RefCell::new(BufferPool::::new())); @@ -45,27 +41,22 @@ impl BackendDevice for CpuDevice { } for (idx, node) in graph.iter().enumerate() { - match node { - Op::BinaryOp { l_id, r_id, .. 
} | Op::InplaceBinaryOp { l_id, r_id, .. } => { - let l_idx = <&GraphTensorId as Into>::into(l_id); - let r_idx = <&GraphTensorId as Into>::into(r_id); - dep_graph.add_edge(l_idx, idx, ()); - dep_graph.add_edge(r_idx, idx, ()); + match &node.op { + Op::BinaryOp { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); } Op::UnaryOp { v_id, .. } => { - let v_idx = <&GraphTensorId as Into>::into(v_id); - dep_graph.add_edge(v_idx, idx, ()); + dep_graph.add_edge(v_id.get(), idx, ()); } - Op::FusedMulAdd { a_id, b_id, c_id } - | Op::InplaceFusedMulAdd { - a_id, b_id, c_id, .. - } => { - let a_idx = <&GraphTensorId as Into>::into(a_id); - let b_idx = <&GraphTensorId as Into>::into(b_id); - let c_idx = <&GraphTensorId as Into>::into(c_id); - dep_graph.add_edge(a_idx, idx, ()); - dep_graph.add_edge(b_idx, idx, ()); - dep_graph.add_edge(c_idx, idx, ()); + Op::FusedMulAdd { a_id, b_id, c_id } => { + dep_graph.add_edge(a_id.get(), idx, ()); + dep_graph.add_edge(b_id.get(), idx, ()); + dep_graph.add_edge(c_id.get(), idx, ()); + } + Op::MatMul { l_id, r_id, .. } => { + dep_graph.add_edge(l_id.get(), idx, ()); + dep_graph.add_edge(r_id.get(), idx, ()); } // NoOp and Fill/Arange don’t create incoming edges Op::NoOp | Op::Fill { .. } | Op::Arange { .. 
} => {} @@ -82,50 +73,41 @@ impl BackendDevice for CpuDevice { // Evaluate nodes in topological order for idx in order { let op = &graph[idx]; - let computed = match op { + + let out_shape = &op.shape; + let out_elem_count: usize = out_shape.iter().product(); + + let computed = match &op.op { Op::BinaryOp { l_id, r_id, operator, } => { - let l_idx = <&GraphTensorId as Into>::into(l_id); - let r_idx = <&GraphTensorId as Into>::into(r_id); - let l_buf = results[l_idx].as_ref().unwrap(); - let r_buf = results[r_idx].as_ref().unwrap(); - let mut out = pool.borrow_mut().get_buffer(S::element_count()); - T::binary_simd_op(l_buf, r_buf, &mut out, *operator); - PooledBuffer::new(out, pool.clone()) - } - Op::InplaceBinaryOp { - out, - l_id, - r_id, - operator, - } => { - let l_idx = <&GraphTensorId as Into>::into(l_id); - let r_idx = <&GraphTensorId as Into>::into(r_id); - let o_idx = <&GraphTensorId as Into>::into(out); - if o_idx == l_idx { - let mut l_buf = results[l_idx].take().unwrap(); - let r_buf = results[r_idx].as_ref().unwrap(); + if l_id.is_inplace() { + let mut l_buf = results[l_id.get()].take().unwrap(); + let r_buf = results[r_id.get()].as_ref().unwrap(); T::binary_simd_op_inplace_lhs(&mut l_buf, r_buf, *operator); l_buf - } else if o_idx == r_idx { - let mut r_buf = results[r_idx].take().unwrap(); - let l_buf = results[l_idx].as_ref().unwrap(); + } else if r_id.is_inplace() { + let mut r_buf = results[r_id.get()].take().unwrap(); + let l_buf = results[l_id.get()].as_ref().unwrap(); T::binary_simd_op_inplace_rhs(l_buf, &mut r_buf, *operator); r_buf } else { - unreachable!() + let l_buf = results[l_id.get()].as_ref().unwrap(); + let r_buf = results[r_id.get()].as_ref().unwrap(); + let mut out = pool.borrow_mut().get_buffer(out_elem_count); + T::binary_simd_op(l_buf, r_buf, &mut out, *operator); + PooledBuffer::new(out, pool.clone()) } } Op::Fill { v } => { - let mut buf = pool.borrow_mut().get_empty_buffer(S::element_count()); - 
buf.extend(std::iter::repeat_n(*v, S::element_count())); + let mut buf = pool.borrow_mut().get_empty_buffer(out_elem_count); + buf.extend(std::iter::repeat_n(*v, out_elem_count)); PooledBuffer::new(buf, pool.clone()) } Op::Arange { start, step, stop } => { - let mut buf = pool.borrow_mut().get_empty_buffer(S::element_count()); + let mut buf = pool.borrow_mut().get_empty_buffer(out_elem_count); let mut x = start.to_f64(); while x < stop.to_f64() { buf.push(T::from_f64(x)); @@ -134,63 +116,80 @@ impl BackendDevice for CpuDevice { PooledBuffer::new(buf, pool.clone()) } Op::UnaryOp { v_id, operator } => { - let v_idx = <&GraphTensorId as Into>::into(v_id); - let buf = results[v_idx].as_ref().unwrap(); + let buf = results[v_id.get()].as_ref().unwrap(); let op_fn = operator.to_closure(); - let mut out = pool.borrow_mut().get_buffer(S::element_count()); + let mut out = pool.borrow_mut().get_buffer(out_elem_count); out.par_iter_mut() .zip(&**buf) .for_each(|(out, x): (&mut T, &T)| *out = op_fn(*x)); PooledBuffer::new(out, pool.clone()) } Op::FusedMulAdd { a_id, b_id, c_id } => { - let a_idx = <&GraphTensorId as Into>::into(a_id); - let b_idx = <&GraphTensorId as Into>::into(b_id); - let c_idx = <&GraphTensorId as Into>::into(c_id); - let a_buf = results[a_idx].as_ref().unwrap(); - let b_buf = results[b_idx].as_ref().unwrap(); - let c_buf = results[c_idx].as_ref().unwrap(); - - let mut out = pool.borrow_mut().get_buffer(S::element_count()); - T::fma_op(a_buf, b_buf, c_buf, &mut out); - PooledBuffer::new(out, pool.clone()) - } - Op::InplaceFusedMulAdd { - a_id, - b_id, - c_id, - out, - } => { - let a_idx = <&GraphTensorId as Into>::into(a_id); - let b_idx = <&GraphTensorId as Into>::into(b_id); - let c_idx = <&GraphTensorId as Into>::into(c_id); - let o_idx = <&GraphTensorId as Into>::into(out); - - if o_idx == a_idx { - let mut a_buf = results[a_idx].take().unwrap(); - let b_buf = results[b_idx].as_ref().unwrap(); - let c_buf = results[c_idx].as_ref().unwrap(); + if 
a_id.is_inplace() { + let mut a_buf = results[a_id.get()].take().unwrap(); + let b_buf = results[b_id.get()].as_ref().unwrap(); + let c_buf = results[c_id.get()].as_ref().unwrap(); T::fma_op_inplace_a(&mut a_buf, b_buf, c_buf); a_buf - } else if o_idx == b_idx { - let mut b_buf = results[b_idx].take().unwrap(); - let a_buf = results[a_idx].as_ref().unwrap(); - let c_buf = results[c_idx].as_ref().unwrap(); + } else if b_id.is_inplace() { + let mut b_buf = results[b_id.get()].take().unwrap(); + let a_buf = results[a_id.get()].as_ref().unwrap(); + let c_buf = results[c_id.get()].as_ref().unwrap(); T::fma_op_inplace_b(a_buf, &mut b_buf, c_buf); b_buf - } else if o_idx == c_idx { - let mut c_buf = results[c_idx].take().unwrap(); - let a_buf = results[a_idx].as_ref().unwrap(); - let b_buf = results[b_idx].as_ref().unwrap(); + } else if c_id.is_inplace() { + let mut c_buf = results[c_id.get()].take().unwrap(); + let a_buf = results[a_id.get()].as_ref().unwrap(); + let b_buf = results[b_id.get()].as_ref().unwrap(); T::fma_op_inplace_c(a_buf, b_buf, &mut c_buf); c_buf } else { - unreachable!() + let a_buf = results[a_id.get()].as_ref().unwrap(); + let b_buf = results[b_id.get()].as_ref().unwrap(); + let c_buf = results[c_id.get()].as_ref().unwrap(); + + let mut out = pool.borrow_mut().get_buffer(out_elem_count); + T::fma_op(a_buf, b_buf, c_buf, &mut out); + PooledBuffer::new(out, pool.clone()) } } + // Matrix multiplication: multiply two 2D tensors A (m x k) and B (k x n) + Op::MatMul { + l_id, + r_id, + o_id, + k, + alpha, + beta, + } => { + // Determine output dimensions from shape S (must be 2D) + let shape = out_shape; + assert!(shape.len() == 3); + let b = shape[0]; + let m = shape[1]; + let n = shape[2]; + + let mut out = if let Some(o_id) = o_id { + if o_id.is_inplace() { + results[o_id.get()].take().unwrap() + } else { + let o_buf = results[o_id.get()].as_ref().unwrap(); + PooledBuffer::new((*o_buf).clone(), pool.clone()) + } + } else { + 
PooledBuffer::new(pool.borrow_mut().get_buffer(m * n), pool.clone()) + }; + + let a_buf = results[l_id.get()].as_ref().unwrap(); + let b_buf = results[r_id.get()].as_ref().unwrap(); + + T::launch_gemm(a_buf, b_buf, b, m, n, *k, &mut out, *alpha, *beta); + + out + } Op::NoOp => unreachable!("NoOp should not be evaluated."), }; results[idx] = Some(computed); diff --git a/constensor-core/src/cpu_storage/pool.rs b/constensor-core/src/cpu_storage/pool.rs index 2f7e6d1..5db02ec 100644 --- a/constensor-core/src/cpu_storage/pool.rs +++ b/constensor-core/src/cpu_storage/pool.rs @@ -22,6 +22,7 @@ pub struct PoolMetrics { pub drops: usize, } +#[derive(Debug)] /// A simple buffer pool to reuse Vec allocations across graph evaluation. pub struct BufferPool { pool: Vec>, @@ -32,6 +33,7 @@ pub struct BufferPool { /// Shared reference to a BufferPool for automatic recycling. pub type SharedPool = Rc>>; +#[derive(Debug)] /// Wrapper around Vec that returns its buffer to the pool on drop. pub struct PooledBuffer { buf: Vec, diff --git a/constensor-core/src/device.rs b/constensor-core/src/device.rs index a233934..4c181fd 100644 --- a/constensor-core/src/device.rs +++ b/constensor-core/src/device.rs @@ -3,7 +3,7 @@ use crate::cuda_backend::CudaDevice; use crate::{ cpu_storage::CpuDevice, storage::{BackendDevice, Storage}, - DType, Op, Result, Shape, + DType, GraphNode, Result, }; /// Marker trait for devices @@ -66,13 +66,11 @@ pub enum Device { } impl Device { - pub fn compile_and_run_graph(&self, graph: &[Op]) -> Result> { + pub fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result> { match self { #[cfg(feature = "cuda")] Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.compile_and_run_graph::(graph)?)), - Self::Cpu => Ok(Storage::Cpu( - CpuDevice.compile_and_run_graph::(graph)?, - )), + Self::Cpu => Ok(Storage::Cpu(CpuDevice.compile_and_run_graph::(graph)?)), } } } diff --git a/constensor-core/src/dtype/gemm.rs b/constensor-core/src/dtype/gemm.rs new file mode 100644 index 
0000000..1bff148 --- /dev/null +++ b/constensor-core/src/dtype/gemm.rs @@ -0,0 +1,234 @@ +use gemm::{gemm, Parallelism}; + +#[cfg(feature = "bfloat")] +use half::bf16; +#[cfg(feature = "half")] +use half::f16; + +pub trait GemmDispatch { + // In bytes, this is also the lane count in bytes + const BLOCK_SIZE: usize = 8; + + #[allow(clippy::too_many_arguments)] + // Matrix multiplication: (B x M x K) * (B x K x N) = (B x M x N) + fn launch_gemm( + lhs: &[Self], + rhs: &[Self], + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut Vec, + alpha: Self, + beta: Self, + ) where + Self: Sized; +} + +macro_rules! instantiate_gemm { + ($rt:ident, $init:expr, NAIVE) => { + impl GemmDispatch for $rt { + fn launch_gemm( + lhs: &[Self], + rhs: &[Self], + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut Vec, + alpha: Self, + beta: Self, + ) where + Self: Sized, + { + for b in 0..b { + for i in 0..m { + for j in 0..n { + let mut sum = $init; + for p in 0..k { + sum += + beta * lhs[b * m * k + i * k + p] * rhs[b * k * n + p * n + j]; + } + out[b * m * n + i * n + j] = alpha * out[b * m * n + i * n + j] + sum; + } + } + } + } + } + }; + + ($rt:ident, $zero:expr, GEMM) => { + impl GemmDispatch for $rt { + fn launch_gemm( + lhs: &[Self], + rhs: &[Self], + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut Vec, + alpha: Self, + beta: Self, + ) where + Self: Sized, + { + let num_threads = num_cpus::get(); + let parallelism = if num_threads > 1 { + Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + + // cs = stride[-1], rs = stride[-2] + let dst_cs = 1; + let dst_rs = n; + + let lhs_cs = 1; + let lhs_rs = k; + + let rhs_cs = 1; + let rhs_rs = n; + + let read_dst = alpha != $zero; + + for b in 0..b { + let lhs_p = &lhs[b * m * k..]; + let rhs_p = &rhs[b * k * n..]; + let out_p = &mut out[b * m * n..]; + + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ out_p.as_mut_ptr(), + /* dst_cs: isize = 
*/ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ read_dst, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ alpha, + /* beta: T = */ beta, + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) + } + } + } + } + }; + // SIMD-accelerated gemm using SimdSupported for vectorized operations along 'n' dimension + ($rt:ident, $init:expr, SIMD) => { + impl GemmDispatch for $rt { + fn launch_gemm( + lhs: &[Self], + rhs: &[Self], + b: usize, + m: usize, + n: usize, + k: usize, + out: &mut Vec, + alpha: Self, + beta: Self, + ) where + Self: Sized, + { + use crate::dtype::SimdSupported; + use crate::graph::BinaryOpType; + const BLOCK_SIZE: usize = <$rt as SimdSupported>::BLOCK_SIZE; + let n_blocks = n / BLOCK_SIZE; + let rem = n % BLOCK_SIZE; + + debug_assert_eq!(lhs.len(), b * m * k); + debug_assert_eq!(rhs.len(), b * k * n); + debug_assert_eq!(out.len(), b * m * n); + + for batch in 0..b { + // Compute base pointers once per batch + let lhs_base = unsafe { lhs.as_ptr().add(batch * m * k) }; + let rhs_base = unsafe { rhs.as_ptr().add(batch * k * n) }; + let out_base = unsafe { out.as_mut_ptr().add(batch * m * n) }; + + for i in 0..m { + // Pointer to the start of the current output row + let out_row_ptr = unsafe { out_base.add(i * n) }; + + // Process full SIMD blocks + for block in 0..n_blocks { + let off = block * BLOCK_SIZE; + let out_ptr = unsafe { out_row_ptr.add(off) }; + let out_chunk = + unsafe { std::slice::from_raw_parts_mut(out_ptr, BLOCK_SIZE) }; + + if beta != $init { + let alpha_arr = [alpha; BLOCK_SIZE]; + ::binary_simd_op_inplace_lhs( + out_chunk, + &alpha_arr, + BinaryOpType::Mul, + ); + } else { + for x in out_chunk.iter_mut() { + *x = $init; + } + } + 
+ for p in 0..k { + let a_val = unsafe { *lhs_base.add(i * k + p) }; + let a_arr = [a_val; BLOCK_SIZE]; + let b_ptr = unsafe { rhs_base.add(p * n + off) }; + let b_chunk = + unsafe { std::slice::from_raw_parts(b_ptr, BLOCK_SIZE) }; + ::fma_op_inplace_c( + &a_arr, b_chunk, out_chunk, + ); + } + } + + // Handle remainder elements + if rem > 0 { + let off = n_blocks * BLOCK_SIZE; + let out_ptr = unsafe { out_row_ptr.add(off) }; + let out_chunk = unsafe { std::slice::from_raw_parts_mut(out_ptr, rem) }; + + if beta != $init { + for x in out_chunk.iter_mut() { + *x *= alpha; + } + } else { + for x in out_chunk.iter_mut() { + *x = $init; + } + } + + for p in 0..k { + let a_val = unsafe { *lhs_base.add(i * k + p) }; + for j in 0..rem { + let b_val = unsafe { *rhs_base.add(p * n + off + j) }; + out_chunk[j] += a_val * b_val; + } + } + } + } + } + } + } + }; +} + +instantiate_gemm!(u8, 0, SIMD); +instantiate_gemm!(u32, 0, SIMD); +instantiate_gemm!(i32, 0, SIMD); +instantiate_gemm!(i64, 0, SIMD); +instantiate_gemm!(f32, 0., GEMM); +instantiate_gemm!(f64, 0., GEMM); +#[cfg(feature = "bfloat")] +instantiate_gemm!(bf16, bf16::from_f32(0.), SIMD); +#[cfg(feature = "half")] +instantiate_gemm!(f16, f16::from_f32(0.), GEMM); diff --git a/constensor-core/src/dtype/mod.rs b/constensor-core/src/dtype/mod.rs new file mode 100644 index 0000000..7a39cc6 --- /dev/null +++ b/constensor-core/src/dtype/mod.rs @@ -0,0 +1,229 @@ +use std::{ + fmt::Debug, + ops::{Add, Div, Mul, Sub}, +}; + +use gemm::GemmDispatch; +#[cfg(feature = "bfloat")] +use half::bf16; +#[cfg(feature = "half")] +use half::f16; + +#[cfg(feature = "cuda")] +use cudarc::driver::DeviceRepr; +use simd_ops::SimdSupported; + +mod gemm; +mod simd_ops; + +/// Type which can be square-rooted. 
+/// If self<0 and Self is integral, then the output is 0 +pub trait Sqrtable { + fn sqrt(&self) -> Self + where + Self: Sized; +} + +impl Sqrtable for f32 { + fn sqrt(&self) -> Self + where + Self: Sized, + { + f32::sqrt(*self) + } +} + +impl Sqrtable for f64 { + fn sqrt(&self) -> Self + where + Self: Sized, + { + f64::sqrt(*self) + } +} + +#[cfg(feature = "bfloat")] +impl Sqrtable for bf16 { + fn sqrt(&self) -> Self + where + Self: Sized, + { + bf16::from_f64_const(self.to_f64_const().sqrt()) + } +} + +#[cfg(feature = "half")] +impl Sqrtable for f16 { + fn sqrt(&self) -> Self + where + Self: Sized, + { + f16::from_f64_const(self.to_f64_const().sqrt()) + } +} + +macro_rules! sqrt_integral { + ($t:ty) => { + impl Sqrtable for $t { + fn sqrt(&self) -> Self + where + Self: Sized, + { + (*self as f64).sqrt() as $t + } + } + }; +} + +sqrt_integral!(u8); +sqrt_integral!(u32); +sqrt_integral!(i32); +sqrt_integral!(i64); + +pub trait DTypeOps: + Copy + + Add + + Div + + Sub + + Mul + + Sqrtable + + SimdSupported + + GemmDispatch +{ +} + +#[cfg(feature = "cuda")] +pub trait DeviceReprLike: DeviceRepr {} + +#[cfg(not(feature = "cuda"))] +pub trait DeviceReprLike {} + +impl DeviceReprLike for u8 {} +impl DeviceReprLike for i32 {} +impl DeviceReprLike for u32 {} +impl DeviceReprLike for i64 {} +impl DeviceReprLike for f32 {} +impl DeviceReprLike for f64 {} + +pub trait MaybeNeg { + const NAME: &'static str; + + /// A fallible version of `neg` that panics on an unsupported type. + fn maybe_neg(self) -> Self; +} + +macro_rules! maybe_neg_failing { + ($rt:ident) => { + impl MaybeNeg for $rt { + const NAME: &'static str = stringify!($rt); + + fn maybe_neg(self) -> Self { + panic!("This type does not support ") + } + } + }; +} + +macro_rules! 
maybe_neg { + ($rt:ident) => { + impl MaybeNeg for $rt { + const NAME: &'static str = stringify!($rt); + + fn maybe_neg(self) -> Self { + -self + } + } + }; +} + +maybe_neg_failing!(u8); +maybe_neg_failing!(u32); +maybe_neg!(i32); +maybe_neg!(i64); +maybe_neg!(f32); +maybe_neg!(f64); + +#[cfg(not(feature = "cuda"))] +/// Marker trait for tensor datatypes. +pub trait DType: + Debug + Clone + DTypeOps + Send + Sync + MaybeNeg + DeviceReprLike + 'static +{ + const ZERO: Self; + const ONE: Self; + const C_NAME: &'static str; + const C_DEP: Option<&'static str>; + const INTEGRAL: bool; + + fn to_f64(&self) -> f64; + fn from_f64(x: f64) -> Self; +} + +macro_rules! dtype { + ($rt:ident, $zero:expr, $one:expr, $c_repr:expr, $integral:expr) => { + impl DTypeOps for $rt {} + impl DType for $rt { + const ZERO: $rt = $zero; + const ONE: $rt = $one; + const C_NAME: &'static str = $c_repr; + const C_DEP: Option<&'static str> = None; + const INTEGRAL: bool = $integral; + + fn to_f64(&self) -> f64 { + *self as f64 + } + fn from_f64(x: f64) -> Self { + x as $rt + } + } + }; +} + +dtype!(u8, 0u8, 1u8, "uint8_t", true); +dtype!(u32, 0u32, 1u32, "uint32_t", true); +dtype!(i32, 0i32, 1i32, "int", true); +dtype!(i64, 0i64, 1i64, "int64_t", true); +dtype!(f32, 0f32, 1f32, "float", false); +dtype!(f64, 0f64, 1f64, "double", false); + +#[cfg(feature = "half")] +impl DTypeOps for f16 {} +#[cfg(feature = "half")] +impl DeviceReprLike for f16 {} +#[cfg(feature = "half")] +maybe_neg!(f16); +#[cfg(feature = "half")] +impl DType for f16 { + const ZERO: f16 = f16::from_f64_const(0.0); + const ONE: f16 = f16::from_f64_const(1.0); + const C_NAME: &'static str = "__half"; + const C_DEP: Option<&'static str> = Some("#include \"cuda_fp16.h\""); + const INTEGRAL: bool = false; + + fn to_f64(&self) -> f64 { + self.to_f64_const() + } + fn from_f64(x: f64) -> Self { + Self::from_f64_const(x) + } +} +#[cfg(feature = "bfloat")] +impl DTypeOps for bf16 {} +#[cfg(feature = "bfloat")] +impl DeviceReprLike for 
bf16 {} +#[cfg(feature = "bfloat")] +maybe_neg!(bf16); +#[cfg(feature = "bfloat")] +impl DType for bf16 { + const ZERO: bf16 = bf16::from_f64_const(0.0); + const ONE: bf16 = bf16::from_f64_const(1.0); + const C_NAME: &'static str = "__nv_bfloat16"; + const C_DEP: Option<&'static str> = Some("#include \"cuda_bf16.h\""); + const INTEGRAL: bool = false; + + fn to_f64(&self) -> f64 { + self.to_f64_const() + } + fn from_f64(x: f64) -> Self { + Self::from_f64_const(x) + } +} diff --git a/constensor-core/src/dtype.rs b/constensor-core/src/dtype/simd_ops.rs similarity index 62% rename from constensor-core/src/dtype.rs rename to constensor-core/src/dtype/simd_ops.rs index 6f72358..923c6f6 100644 --- a/constensor-core/src/dtype.rs +++ b/constensor-core/src/dtype/simd_ops.rs @@ -1,227 +1,10 @@ -use std::{ - fmt::Debug, - ops::{Add, Div, Mul, Sub}, -}; - #[cfg(feature = "bfloat")] use half::bf16; #[cfg(feature = "half")] use half::f16; -#[cfg(feature = "cuda")] -use cudarc::driver::DeviceRepr; - use crate::graph::BinaryOpType; -/// Type which can be square-rooted. -/// If self<0 and Self is integral, then te output is 0 -pub trait Sqrtable { - fn sqrt(&self) -> Self - where - Self: Sized; -} - -impl Sqrtable for f32 { - fn sqrt(&self) -> Self - where - Self: Sized, - { - f32::sqrt(*self) - } -} - -impl Sqrtable for f64 { - fn sqrt(&self) -> Self - where - Self: Sized, - { - f64::sqrt(*self) - } -} - -#[cfg(feature = "bfloat")] -impl Sqrtable for bf16 { - fn sqrt(&self) -> Self - where - Self: Sized, - { - bf16::from_f64_const(self.to_f64_const().sqrt()) - } -} - -#[cfg(feature = "half")] -impl Sqrtable for f16 { - fn sqrt(&self) -> Self - where - Self: Sized, - { - f16::from_f64_const(self.to_f64_const().sqrt()) - } -} - -macro_rules! 
sqrt_integral { - ($t:ty) => { - impl Sqrtable for $t { - fn sqrt(&self) -> Self - where - Self: Sized, - { - (*self as f64).sqrt() as $t - } - } - }; -} - -sqrt_integral!(u8); -sqrt_integral!(u32); -sqrt_integral!(i32); -sqrt_integral!(i64); - -pub trait DTypeOps: - Copy - + Add - + Div - + Sub - + Mul - + Sqrtable - + SimdSupported -{ -} - -#[cfg(feature = "cuda")] -pub trait DeviceReprLike: DeviceRepr {} - -#[cfg(not(feature = "cuda"))] -pub trait DeviceReprLike {} - -impl DeviceReprLike for u8 {} -impl DeviceReprLike for i32 {} -impl DeviceReprLike for u32 {} -impl DeviceReprLike for i64 {} -impl DeviceReprLike for f32 {} -impl DeviceReprLike for f64 {} - -pub trait MaybeNeg { - const NAME: &'static str; - - /// A fallible version of `neg` that panics on an unsupported type. - fn maybe_neg(self) -> Self; -} - -macro_rules! maybe_neg_failing { - ($rt:ident) => { - impl MaybeNeg for $rt { - const NAME: &'static str = stringify!($rt); - - fn maybe_neg(self) -> Self { - panic!("This type does not support ") - } - } - }; -} - -macro_rules! maybe_neg { - ($rt:ident) => { - impl MaybeNeg for $rt { - const NAME: &'static str = stringify!($rt); - - fn maybe_neg(self) -> Self { - -self - } - } - }; -} - -maybe_neg_failing!(u8); -maybe_neg_failing!(u32); -maybe_neg!(i32); -maybe_neg!(i64); -maybe_neg!(f32); -maybe_neg!(f64); - -#[cfg(not(feature = "cuda"))] -/// Marker trait for tensor datatypes. -pub trait DType: Debug + Clone + DTypeOps + Send + Sync + MaybeNeg + DeviceReprLike { - const ZERO: Self; - const ONE: Self; - const C_NAME: &'static str; - const C_DEP: Option<&'static str>; - const INTEGRAL: bool; - - fn to_f64(&self) -> f64; - fn from_f64(x: f64) -> Self; -} - -macro_rules! 
dtype { - ($rt:ident, $zero:expr, $one:expr, $c_repr:expr, $integral:expr) => { - impl DTypeOps for $rt {} - impl DType for $rt { - const ZERO: $rt = $zero; - const ONE: $rt = $one; - const C_NAME: &'static str = $c_repr; - const C_DEP: Option<&'static str> = None; - const INTEGRAL: bool = $integral; - - fn to_f64(&self) -> f64 { - *self as f64 - } - fn from_f64(x: f64) -> Self { - x as $rt - } - } - }; -} - -dtype!(u8, 0u8, 1u8, "uint8_t", true); -dtype!(u32, 0u32, 1u32, "uint32_t", true); -dtype!(i32, 0i32, 1i32, "int", true); -dtype!(i64, 0i64, 1i64, "int64_t", true); -dtype!(f32, 0f32, 1f32, "float", false); -dtype!(f64, 0f64, 1f64, "double", false); - -#[cfg(feature = "half")] -impl DTypeOps for f16 {} -#[cfg(feature = "half")] -impl DeviceReprLike for f16 {} -#[cfg(feature = "half")] -maybe_neg!(f16); -#[cfg(feature = "half")] -impl DType for f16 { - const ZERO: f16 = f16::from_f64_const(0.0); - const ONE: f16 = f16::from_f64_const(1.0); - const C_NAME: &'static str = "__half"; - const C_DEP: Option<&'static str> = Some("#include \"cuda_fp16.h\""); - const INTEGRAL: bool = false; - - fn to_f64(&self) -> f64 { - self.to_f64_const() - } - fn from_f64(x: f64) -> Self { - Self::from_f64_const(x) - } -} -#[cfg(feature = "bfloat")] -impl DTypeOps for bf16 {} -#[cfg(feature = "bfloat")] -impl DeviceReprLike for bf16 {} -#[cfg(feature = "bfloat")] -maybe_neg!(bf16); -#[cfg(feature = "bfloat")] -impl DType for bf16 { - const ZERO: bf16 = bf16::from_f64_const(0.0); - const ONE: bf16 = bf16::from_f64_const(1.0); - const C_NAME: &'static str = "__nv_bfloat16"; - const C_DEP: Option<&'static str> = Some("#include \"cuda_bf16.h\""); - const INTEGRAL: bool = false; - - fn to_f64(&self) -> f64 { - self.to_f64_const() - } - fn from_f64(x: f64) -> Self { - Self::from_f64_const(x) - } -} - pub trait SimdSupported { // In bytes, this is also the lane count in bytes const BLOCK_SIZE: usize = 8; @@ -286,8 +69,13 @@ macro_rules! 
simd_supported { // Vectorized loop for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let l_chunk = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let r_chunk = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let l_chunk: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let r_chunk: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; let res = simd_op(l_chunk, r_chunk); out[off..(off + Self::BLOCK_SIZE).min(len)].copy_from_slice(res.as_array()); } @@ -326,8 +114,13 @@ macro_rules! simd_supported { // Vectorized loop for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let l_chunk = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let r_chunk = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let l_chunk: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let r_chunk: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; let res = simd_op(l_chunk, r_chunk); a[off..(off + Self::BLOCK_SIZE).min(len)].copy_from_slice(res.as_array()); } @@ -366,8 +159,13 @@ macro_rules! 
simd_supported { // Vectorized loop for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let l_chunk = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let r_chunk = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let l_chunk: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let r_chunk: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; let res = simd_op(l_chunk, r_chunk); b[off..(off + Self::BLOCK_SIZE).min(len)].copy_from_slice(res.as_array()); } @@ -392,9 +190,16 @@ macro_rules! simd_supported { use std::simd::StdFloat; for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let a = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let b = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let c = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let a: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let b: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let c: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = a.mul_add(b, c); out[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } @@ -413,9 +218,16 @@ macro_rules! 
simd_supported { use std::simd::StdFloat; for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let ax = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let b = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let c = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let ax: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let b: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let c: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = ax.mul_add(b, c); a[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } @@ -434,9 +246,16 @@ macro_rules! simd_supported { use std::simd::StdFloat; for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let a = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let bx = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let c = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let a: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let bx: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let c: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = a.mul_add(bx, c); b[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } @@ -455,9 +274,16 @@ macro_rules! 
simd_supported { use std::simd::StdFloat; for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let a = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let b = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let cx = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let a: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let b: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let cx: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = a.mul_add(b, cx); c[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } @@ -480,9 +306,16 @@ macro_rules! simd_supported { for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let a = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let b = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let c = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let a: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let b: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let c: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = a * b + c; out[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } @@ -500,9 +333,16 @@ macro_rules! 
simd_supported { for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let ax = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let b = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let c = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let ax: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let b: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let c: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = ax * b + c; a[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } @@ -520,9 +360,16 @@ macro_rules! simd_supported { for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let a = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let bx = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let c = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let a: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let bx: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let c: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = a * bx + c; b[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } @@ -540,9 +387,16 @@ macro_rules! 
simd_supported { for i in 0..n_blocks { let off = i * Self::BLOCK_SIZE; - let a = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&a[off..off + Self::BLOCK_SIZE]); - let b = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&b[off..off + Self::BLOCK_SIZE]); - let cx = std::simd::Simd::<$t, { Self::BLOCK_SIZE }>::from_slice(&c[off..off + Self::BLOCK_SIZE]); + // SAFETY: the invariant is upheld with the loop condition + let a: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(a.as_ptr().add(off) as *const _) + }; + let b: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(b.as_ptr().add(off) as *const _) + }; + let cx: std::simd::Simd<$t, { Self::BLOCK_SIZE }> = unsafe { + std::ptr::read_unaligned(c.as_ptr().add(off) as *const _) + }; let res = a * b + cx; c[off..off + Self::BLOCK_SIZE].copy_from_slice(res.as_array()); } diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index 12ff37e..f71f229 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -4,23 +4,27 @@ use std::{ env, fmt::Display, fs, + hash::Hash, path::Path, process::Command, rc::Rc, sync::{Arc, RwLock, RwLockReadGuard}, }; -use crate::{DType, Result}; +use crate::{DType, Result, Shape}; use petgraph::Graph as PetGraph; -use petgraph::{ - dot::{Config, Dot}, - graph::NodeIndex, -}; +use petgraph::{dot::Dot, graph::NodeIndex}; + +#[derive(Clone, Debug)] +pub struct GraphNode { + pub op: Op, + pub shape: Vec, +} #[derive(Clone)] pub struct Graph { - data: Arc>>>, + data: Arc>>>, id: Arc>, } @@ -34,51 +38,51 @@ impl Graph { } /// Read-only access to the list of operations - pub fn get_ops(&self) -> RwLockReadGuard>> { + pub fn get_ops(&self) -> RwLockReadGuard>> { self.data.read().unwrap() } /// Append an operation to the graph - pub(crate) fn add_op(&self, op: Op) { - self.data.write().unwrap().push(op); + pub(crate) fn add_op(&self, op: Op) { + self.data.write().unwrap().push(GraphNode { 
+ op, + shape: S::shape(), + }); } /// Generate the next unique tensor ID #[must_use] pub(crate) fn next_id(&mut self) -> GraphTensorId { - let next = GraphTensorId::from(*self.id.read().unwrap()); + let next = GraphTensorId::out_of_place(*self.id.read().unwrap()); *self.id.write().unwrap() += 1; next } - pub fn to_petgraph(&self) -> PetGraph { + pub fn to_petgraph(&self) -> PetGraph { let ops = self.data.read().unwrap(); - let mut g = PetGraph::::new(); + let mut g = PetGraph::::new(); // map from op‐index → Some(node) if we created a node, or None if it was a NoOp let mut idx_map: Vec> = Vec::with_capacity(ops.len()); // 1) Add only non‐NoOp nodes for op in ops.iter() { - match op { + match op.op { Op::NoOp => { idx_map.push(None); } _ => { - let label = match op { - Op::Fill { v } => format!("Fill({:?})", v), - Op::Arange { start, step, stop } => { - format!( - "Arange(start={:?}, step={:?}, stop={:?})", - start, step, stop - ) + let label = match &op.op { + Op::Fill { v, .. } => format!("Fill({v:?})"), + Op::Arange { + start, step, stop, .. + } => { + format!("Arange(start={start:?}, step={step:?}, stop={stop:?})") } Op::BinaryOp { operator, .. } => format!("BinOp({})", operator.as_c_op()), - Op::InplaceBinaryOp { operator, .. } => { - format!("InplaceBinOp({})", operator.as_c_op()) - } - Op::UnaryOp { operator, .. } => format!("UnOp({:?})", operator), + Op::UnaryOp { operator, .. } => format!("UnOp({operator:?})"), Op::FusedMulAdd { .. } => "FMA".to_string(), - Op::InplaceFusedMulAdd { .. } => "InplaceFMA".to_string(), + // Matrix multiplication + Op::MatMul { .. } => "MatMul".to_string(), // we already matched NoOp above Op::NoOp => unreachable!(), }; @@ -95,27 +99,69 @@ impl Graph { Some(dst) => dst, None => continue, }; - match op { - Op::BinaryOp { l_id, r_id, .. } | Op::InplaceBinaryOp { l_id, r_id, .. } => { - if let Some(src) = idx_map[usize::from(l_id)] { - g.add_edge(src, dst, ()); + match &op.op { + Op::BinaryOp { l_id, r_id, .. 
} => { + if let Some(src) = idx_map[l_id.get()] { + let mut label = "l".to_string(); + if l_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); } - if let Some(src) = idx_map[usize::from(r_id)] { - g.add_edge(src, dst, ()); + if let Some(src) = idx_map[r_id.get()] { + let mut label = "r".to_string(); + if r_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); } } Op::UnaryOp { v_id, .. } => { - if let Some(src) = idx_map[usize::from(v_id)] { - g.add_edge(src, dst, ()); + if let Some(src) = idx_map[v_id.get()] { + let mut label = "v".to_string(); + if v_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); } } - Op::FusedMulAdd { a_id, b_id, c_id } - | Op::InplaceFusedMulAdd { + Op::FusedMulAdd { a_id, b_id, c_id, .. } => { - for src_id in [a_id, b_id, c_id] { - if let Some(src) = idx_map[usize::from(src_id)] { - g.add_edge(src, dst, ()); + for (prefix, src_id) in [("a", a_id), ("b", b_id), ("c", c_id)].iter() { + if let Some(src) = idx_map[src_id.get()] { + let mut label = prefix.to_string(); + if src_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); + } + } + } + Op::MatMul { + l_id, r_id, o_id, .. + } => { + if let Some(src) = idx_map[l_id.get()] { + let mut label = "l".to_string(); + if l_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); + } + if let Some(src) = idx_map[r_id.get()] { + let mut label = "r".to_string(); + if r_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); + } + if let Some(o_id) = o_id { + if let Some(src) = idx_map[o_id.get()] { + let mut label = "o".to_string(); + if o_id.is_inplace() { + label.push('*'); + } + g.add_edge(src, dst, label.clone()); } } } @@ -130,7 +176,7 @@ impl Graph { /// Produce a DOT format string of this graph. 
pub fn to_dot(&self) -> String { let g = self.to_petgraph(); - format!("{:?}", Dot::with_config(&g, &[Config::EdgeNoLabel])) + format!("{:?}", Dot::with_config(&g, &[])) } /// Visualize the graph by saving it to this file. @@ -171,60 +217,64 @@ impl Graph { l_id: a_id, r_id: b_id, operator: BinaryOpType::Mul, - } = x + } = &x.op { // Check if next op uses this if let Op::BinaryOp { l_id: l_y, r_id: r_y, operator: BinaryOpType::Add, - } = &ops[x_id + 1] + } = &ops[x_id + 1].op { let y_id = x_id + 1; - if <&GraphTensorId as Into>::into(l_y) == x_id - || <&GraphTensorId as Into>::into(r_y) == x_id - { + if l_y.get() == x_id || r_y.get() == x_id && x.shape == ops[x_id + 1].shape { // Want to see what is being added to the result of the mul - let rhs_add = if <&GraphTensorId as Into>::into(l_y) == x_id { - r_y - } else { - l_y + let rhs_add = if l_y.get() == x_id { r_y } else { l_y }; + new_ops[y_id] = GraphNode { + op: Op::FusedMulAdd { + a_id: a_id.clone(), + b_id: b_id.clone(), + c_id: rhs_add.clone(), + }, + shape: x.shape.clone(), }; - new_ops[y_id] = Op::FusedMulAdd { - a_id: GraphTensorId::from(a_id.0.get()), - b_id: GraphTensorId::from(b_id.0.get()), - c_id: GraphTensorId::from(rhs_add.0.get()), + new_ops[x_id] = GraphNode { + op: Op::NoOp, + shape: x.shape.clone(), }; - new_ops[x_id] = Op::NoOp; // Look for ops which actually use this one - for user in ops.iter() { - let ids = match user { + for user in new_ops.iter() { + let ids = match &user.op { Op::Arange { start: _, step: _, stop: _, + .. } => vec![], - Op::BinaryOp { l_id, r_id, .. } - | Op::InplaceBinaryOp { l_id, r_id, .. } => vec![l_id, r_id], - Op::Fill { v: _ } => vec![], - Op::UnaryOp { v_id, operator: _ } => vec![v_id], - Op::FusedMulAdd { a_id, b_id, c_id } - | Op::InplaceFusedMulAdd { + Op::BinaryOp { l_id, r_id, .. } => vec![l_id, r_id], + Op::Fill { v: _, .. } => vec![], + Op::UnaryOp { + v_id, operator: _, .. + } => vec![v_id], + Op::FusedMulAdd { a_id, b_id, c_id, .. 
} => { vec![a_id, b_id, c_id] } + Op::MatMul { l_id, r_id, .. } => vec![l_id, r_id], Op::NoOp => vec![], }; + + // We are going to remove the noop so this is necessary to fix the indices. let used_ids = ids .into_iter() - .filter(|id| <&GraphTensorId as Into>::into(id) != y_id) + .filter(|id| id.get() == y_id) .collect::>(); if !used_ids.is_empty() { for id in used_ids { // Tell the ops which use the result of the fma to source from there - id.0.set(y_id); + id.set(x_id); } } } @@ -236,30 +286,35 @@ impl Graph { // Remove any NoOp entries before storing back to the graph let filtered_ops = new_ops .into_iter() - .filter(|op| !matches!(op, Op::NoOp)) + .filter(|op| !matches!(op.op, Op::NoOp)) .collect::>(); *self.data.write().unwrap() = filtered_ops; } /// Count how often each tensor id is used as an input. - fn count_input_usage(ops: &[Op]) -> HashMap { - let mut usage: HashMap = HashMap::new(); + #[allow(clippy::mutable_key_type)] + fn count_input_usage(ops: &[GraphNode]) -> HashMap { + #[allow(clippy::mutable_key_type)] + let mut usage: HashMap = HashMap::new(); for op in ops { - match op { - Op::BinaryOp { l_id, r_id, .. } | Op::InplaceBinaryOp { l_id, r_id, .. } => { - *usage.entry(usize::from(l_id)).or_default() += 1; - *usage.entry(usize::from(r_id)).or_default() += 1; + match &op.op { + Op::BinaryOp { l_id, r_id, .. } => { + *usage.entry(l_id.clone()).or_default() += 1; + *usage.entry(r_id.clone()).or_default() += 1; } Op::UnaryOp { v_id, .. } => { - *usage.entry(usize::from(v_id)).or_default() += 1; + *usage.entry(v_id.clone()).or_default() += 1; } - Op::FusedMulAdd { a_id, b_id, c_id } - | Op::InplaceFusedMulAdd { + Op::FusedMulAdd { a_id, b_id, c_id, .. 
} => { - *usage.entry(usize::from(a_id)).or_default() += 1; - *usage.entry(usize::from(b_id)).or_default() += 1; - *usage.entry(usize::from(c_id)).or_default() += 1; + *usage.entry(a_id.clone()).or_default() += 1; + *usage.entry(b_id.clone()).or_default() += 1; + *usage.entry(c_id.clone()).or_default() += 1; + } + Op::MatMul { l_id, r_id, .. } => { + *usage.entry(l_id.clone()).or_default() += 1; + *usage.entry(r_id.clone()).or_default() += 1; } Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} } @@ -271,6 +326,7 @@ impl Graph { fn optimize_inplace_bin(&mut self) { let ops = self.data.write().unwrap().clone(); let mut new_ops = ops.clone(); + #[allow(clippy::mutable_key_type)] let usage = Self::count_input_usage(&ops); // Transform eligible BinaryOps into InplaceBinaryOps. for (i, op) in ops.iter().enumerate() { @@ -278,13 +334,11 @@ impl Graph { l_id, r_id, operator, - } = op + } = &op.op { - let l_idx = usize::from(l_id); - let r_idx = usize::from(r_id); - let l_use = usage.get(&l_idx).copied().unwrap_or(0); - let r_use = usage.get(&r_idx).copied().unwrap_or(0); - if l_use == 1 || r_use == 1 { + let l_use = usage.get(l_id).copied().unwrap_or(0); + let r_use = usage.get(r_id).copied().unwrap_or(0); + if l_use <= 1 || r_use <= 1 { // Choose target for in-place: if both, default to lhs. let target = if r_use > l_use { r_id.clone() @@ -292,46 +346,14 @@ impl Graph { l_id.clone() }; // Replace with InplaceBinaryOp - new_ops[i] = Op::InplaceBinaryOp { - out: target.clone(), - l_id: l_id.clone(), - r_id: r_id.clone(), - operator: *operator, + new_ops[i] = GraphNode { + op: Op::BinaryOp { + l_id: l_id.clone().to_inplace_if(&target == l_id), + r_id: r_id.clone().to_inplace_if(&target == r_id), + operator: *operator, + }, + shape: op.shape.clone(), }; - // Update all future uses of this op's result (index i) to use 'target'. - for fut in new_ops.iter_mut().skip(i + 1) { - match fut { - Op::BinaryOp { l_id, r_id, .. } - | Op::InplaceBinaryOp { l_id, r_id, .. 
} => { - if usize::from(&*l_id) == i { - l_id.0.set(usize::from(&target)); - } - if usize::from(&*r_id) == i { - r_id.0.set(usize::from(&target)); - } - } - Op::UnaryOp { v_id, .. } => { - if usize::from(&*v_id) == i { - v_id.0.set(usize::from(&target)); - } - } - Op::FusedMulAdd { a_id, b_id, c_id } - | Op::InplaceFusedMulAdd { - a_id, b_id, c_id, .. - } => { - if usize::from(&*a_id) == i { - a_id.0.set(usize::from(&target)); - } - if usize::from(&*b_id) == i { - b_id.0.set(usize::from(&target)); - } - if usize::from(&*c_id) == i { - c_id.0.set(usize::from(&target)); - } - } - Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} - } - } } } } @@ -343,62 +365,69 @@ impl Graph { fn optimize_inplace_fma(&mut self) { let ops = self.data.write().unwrap().clone(); let mut new_ops = ops.clone(); + #[allow(clippy::mutable_key_type)] let usage = Self::count_input_usage(&ops); for (i, op) in ops.iter().enumerate() { - if let Op::FusedMulAdd { a_id, b_id, c_id } = op { + if let Op::FusedMulAdd { a_id, b_id, c_id } = &op.op { let mut target = None; // If an input is used only once, we can reuse its buffer; default order: a_id, then b_id, then c_id - if *usage.get(&usize::from(a_id)).unwrap_or(&0) == 1 { + if *usage.get(a_id).unwrap_or(&0) <= 1 { target = Some(a_id.clone()); - } else if *usage.get(&usize::from(b_id)).unwrap_or(&0) == 1 { + } else if *usage.get(b_id).unwrap_or(&0) <= 1 { target = Some(b_id.clone()); - } else if *usage.get(&usize::from(c_id)).unwrap_or(&0) == 1 { + } else if *usage.get(c_id).unwrap_or(&0) <= 1 { target = Some(c_id.clone()); } if let Some(out) = target { - new_ops[i] = Op::InplaceFusedMulAdd { - out: out.clone(), - a_id: a_id.clone(), - b_id: b_id.clone(), - c_id: c_id.clone(), + new_ops[i] = GraphNode { + op: Op::FusedMulAdd { + a_id: a_id.clone().to_inplace_if(&out == a_id), + b_id: b_id.clone().to_inplace_if(&out == b_id), + c_id: c_id.clone().to_inplace_if(&out == c_id), + }, + shape: op.shape.clone(), + }; + } + } + } + 
*self.data.write().unwrap() = new_ops; + } + + /// Optimize by inplacing the output of a matmul when inputs are not reused. + fn optimize_inplace_matmul(&mut self) { + let ops = self.data.write().unwrap().clone(); + let mut new_ops = ops.clone(); + #[allow(clippy::mutable_key_type)] + let usage = Self::count_input_usage(&ops); + // Transform eligible BinaryOps into InplaceBinaryOps. + for (i, op) in ops.iter().enumerate() { + if let Op::MatMul { + o_id: Some(o_id), + l_id, + r_id, + k, + alpha, + beta, + } = &op.op + { + let o_use = usage.get(o_id).copied().unwrap_or(0); + if o_use <= 1 { + // Replace with InplaceBinaryOp + new_ops[i] = GraphNode { + op: Op::MatMul { + o_id: Some(o_id.to_inplace()), + l_id: l_id.clone(), + r_id: r_id.clone(), + k: *k, + alpha: *alpha, + beta: *beta, + }, + shape: op.shape.clone(), }; - // Update all future ops that reference the original index i - for fut in new_ops.iter_mut().skip(i + 1) { - match fut { - Op::BinaryOp { l_id, r_id, .. } - | Op::InplaceBinaryOp { l_id, r_id, .. } => { - if usize::from(&*l_id) == i { - l_id.0.set(usize::from(&out)); - } - if usize::from(&*r_id) == i { - r_id.0.set(usize::from(&out)); - } - } - Op::UnaryOp { v_id, .. } => { - if usize::from(&*v_id) == i { - v_id.0.set(usize::from(&out)); - } - } - Op::FusedMulAdd { a_id, b_id, c_id } - | Op::InplaceFusedMulAdd { - a_id, b_id, c_id, .. - } => { - if usize::from(&*a_id) == i { - a_id.0.set(usize::from(&out)); - } - if usize::from(&*b_id) == i { - b_id.0.set(usize::from(&out)); - } - if usize::from(&*c_id) == i { - c_id.0.set(usize::from(&out)); - } - } - Op::NoOp | Op::Fill { .. } | Op::Arange { .. } => {} - } - } } } } + // Commit the transformed op list. 
*self.data.write().unwrap() = new_ops; } @@ -410,6 +439,7 @@ impl Graph { self.optimize_fma(); self.optimize_inplace_bin(); self.optimize_inplace_fma(); + self.optimize_inplace_matmul(); } } @@ -478,12 +508,6 @@ pub enum Op { r_id: GraphTensorId, operator: BinaryOpType, }, - InplaceBinaryOp { - out: GraphTensorId, - l_id: GraphTensorId, - r_id: GraphTensorId, - operator: BinaryOpType, - }, UnaryOp { v_id: GraphTensorId, operator: UnaryOpType, @@ -494,34 +518,67 @@ pub enum Op { b_id: GraphTensorId, c_id: GraphTensorId, }, - /// a * b + c - InplaceFusedMulAdd { - out: GraphTensorId, - a_id: GraphTensorId, - b_id: GraphTensorId, - c_id: GraphTensorId, + /// (B x M x K) * (B x K x N) = (B x M x N) + /// out = out * alpha + beta * lhs * rhs + MatMul { + l_id: GraphTensorId, + r_id: GraphTensorId, + o_id: Option, + k: usize, + alpha: T, + beta: T, }, NoOp, } -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, PartialEq, Debug, Eq)] /// Graph tensor IDs can be cloned. -pub struct GraphTensorId(Rc>); +pub enum GraphTensorId { + OutOfPlace(Rc>), + InPlace(Rc>), +} -impl From for usize { - fn from(value: GraphTensorId) -> Self { - value.0.get() +impl Hash for GraphTensorId { + fn hash(&self, state: &mut H) { + state.write_usize(self.get()); } } -impl From<&GraphTensorId> for usize { - fn from(value: &GraphTensorId) -> Self { - value.0.get() +impl GraphTensorId { + pub fn out_of_place(value: usize) -> Self { + Self::OutOfPlace(Rc::new(Cell::new(value))) + } + + pub fn inplace(value: usize) -> Self { + Self::InPlace(Rc::new(Cell::new(value))) + } + + pub fn to_inplace(&self) -> Self { + match self { + Self::OutOfPlace(x) | Self::InPlace(x) => Self::inplace(x.get()), + } + } + + pub fn to_inplace_if(&self, predicate: bool) -> Self { + match self { + Self::OutOfPlace(x) | Self::InPlace(x) if predicate => Self::inplace(x.get()), + _ => self.clone(), + } + } + + pub fn get(&self) -> usize { + match self { + GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.get(), + } 
+ } + + pub fn set(&self, value: usize) { + match self { + GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.set(value), + } } -} -impl From for GraphTensorId { - fn from(value: usize) -> Self { - Self(Rc::new(Cell::new(value))) + pub fn is_inplace(&self) -> bool { + matches!(self, Self::InPlace(_)) } } diff --git a/constensor-core/src/lib.rs b/constensor-core/src/lib.rs index 31b530e..c8af30a 100644 --- a/constensor-core/src/lib.rs +++ b/constensor-core/src/lib.rs @@ -16,6 +16,6 @@ pub use device::Cpu; pub use device::Cuda; pub use dtype::DType; pub use error::{Error, Result}; -pub use graph::{Graph, Op}; +pub use graph::{Graph, GraphNode, Op}; pub use shape::{Shape, R1, R2, R3, R4, R5, R6}; pub use tensor::{GraphTensor, Tensor}; diff --git a/constensor-core/src/storage.rs b/constensor-core/src/storage.rs index 12af78c..2d10bf6 100644 --- a/constensor-core/src/storage.rs +++ b/constensor-core/src/storage.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; #[cfg(feature = "cuda")] use crate::cuda_backend::CudaStorage; -use crate::{cpu_storage::CpuStorage, DType, Op, Result, Shape}; +use crate::{cpu_storage::CpuStorage, DType, GraphNode, Result}; pub enum Storage { #[cfg(feature = "cuda")] @@ -27,8 +27,5 @@ pub trait BackendStorage { pub trait BackendDevice { type Storage: BackendStorage; - fn compile_and_run_graph( - &self, - graph: &[Op], - ) -> Result>; + fn compile_and_run_graph(&self, graph: &[GraphNode]) -> Result>; } diff --git a/constensor-core/src/tensor/graphtensor.rs b/constensor-core/src/tensor/graphtensor.rs index fbd0756..6269718 100644 --- a/constensor-core/src/tensor/graphtensor.rs +++ b/constensor-core/src/tensor/graphtensor.rs @@ -8,7 +8,7 @@ use crate::{ device::Dev, graph::{BinaryOpType, Graph, GraphTensorId, Op, UnaryOpType}, tensor::concretetensor::from_storage, - DType, Result, Shape, Tensor, R1, + DType, Result, Shape, Tensor, R1, R3, }; /// A tensor representing an intermediary result of a graph. 
Performing operations @@ -20,12 +20,68 @@ pub struct GraphTensor { _ghost: PhantomData<(S, T, D)>, } +impl + GraphTensor, T, D> +{ + #[must_use] + // Matrix multiplication: (B x M x K) * (B x K x N) = (B x M x N) + pub fn matmul( + self, + rhs: GraphTensor, T, D>, + ) -> GraphTensor, T, D> { + self.graph + .write() + .unwrap() + .add_op::>(Op::MatMul { + l_id: self.id(), + r_id: rhs.id(), + o_id: None, + k: K, + alpha: T::ZERO, + beta: T::ONE, + }); + GraphTensor { + id: self.graph.write().unwrap().next_id(), + graph: self.graph.clone(), + _ghost: PhantomData, + } + } + + #[must_use] + // Matrix multiplication: (B x M x K) * (B x K x N) = (B x M x N) + /// out = out * alpha + beta * lhs * rhs + pub fn matmul_axpby( + self, + rhs: GraphTensor, T, D>, + out: GraphTensor, T, D>, + alpha: T, + beta: T, + ) -> GraphTensor, T, D> { + self.graph + .write() + .unwrap() + .add_op::>(Op::MatMul { + l_id: self.id(), + r_id: rhs.id(), + o_id: Some(out.id()), + k: K, + alpha, + beta, + }); + GraphTensor { + id: self.graph.write().unwrap().next_id(), + graph: self.graph.clone(), + _ghost: PhantomData, + } + } +} + impl GraphTensor { #[must_use] /// Create a tensor filled with some value. pub fn fill(graph: &mut Graph, v: T) -> Self { let id = graph.next_id(); - graph.add_op(Op::Fill { v }); + graph.add_op::(Op::Fill { v }); Self { id, graph: Arc::new(RwLock::new(graph.clone())), @@ -48,7 +104,7 @@ impl GraphTensor { #[must_use] /// Elementwise unary square root. 
pub fn sqrt(self) -> GraphTensor { - self.graph.write().unwrap().add_op(Op::UnaryOp { + self.graph.write().unwrap().add_op::(Op::UnaryOp { v_id: self.id(), operator: UnaryOpType::Sqrt, }); @@ -77,7 +133,7 @@ impl GraphTensor { let nodes = &*graph.get_ops(); let device = D::resolve()?; - let storage = device.compile_and_run_graph::(nodes)?; + let storage = device.compile_and_run_graph::(nodes)?; Ok(from_storage(Arc::new(storage))) } } @@ -88,7 +144,7 @@ impl GraphTensor, T, D> { pub fn arange(graph: &mut Graph, start: T, stop: T) -> Self { let id = graph.next_id(); let step = (stop.to_f64() - start.to_f64()) / (A as f64); - graph.add_op(Op::Arange { + graph.add_op::>(Op::Arange { start, step: T::from_f64(step), stop, @@ -107,7 +163,7 @@ macro_rules! graphtensor_binop { type Output = GraphTensor; /// Add an elementwise operation to the graph. fn $fn_name(self, rhs: Self) -> Self::Output { - self.graph.write().unwrap().add_op(Op::BinaryOp { + self.graph.write().unwrap().add_op::(Op::BinaryOp { l_id: self.id(), r_id: rhs.id(), operator: BinaryOpType::$trait, @@ -131,7 +187,7 @@ impl, D: Dev> Neg for GraphTensor type Output = GraphTensor; /// Add an elementwise addition operation to the graph. fn neg(self) -> Self::Output { - self.graph.write().unwrap().add_op(Op::UnaryOp { + self.graph.write().unwrap().add_op::(Op::UnaryOp { v_id: self.id(), operator: UnaryOpType::Neg, }); diff --git a/constensor-core/tests/ops.rs b/constensor-core/tests/ops.rs index 41bfce0..deba8ec 100644 --- a/constensor-core/tests/ops.rs +++ b/constensor-core/tests/ops.rs @@ -1,6 +1,6 @@ #[cfg(feature = "cuda")] use constensor_core::Cuda; -use constensor_core::{Cpu, Graph, GraphTensor, R1, R2}; +use constensor_core::{Cpu, Graph, GraphTensor, R1, R2, R3}; #[cfg(feature = "bfloat")] use half::bf16; #[cfg(feature = "half")] @@ -53,6 +53,29 @@ macro_rules! 
test_for_device_float { let tensor = res.to_tensor().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![1.0, 1.25, 1.5, 1.75]); } + + #[test] + fn matmul() { + let mut graph = Graph::empty(); + let a = GraphTensor::, f32, $dev>::ones(&mut graph); + let b = GraphTensor::, f32, $dev>::ones(&mut graph); + let c = a.matmul(b); + let tensor = c.to_tensor().unwrap(); + let expected: [Vec<[f32; 2]>; 1] = [vec![[3.0, 3.0], [3.0, 3.0]]]; + assert_eq!(tensor.data().unwrap().to_vec(), expected); + } + + #[test] + fn matmul_axpby() { + let mut graph = Graph::empty(); + let a = GraphTensor::, f32, $dev>::ones(&mut graph); + let b = GraphTensor::, f32, $dev>::ones(&mut graph); + let o = GraphTensor::, f32, $dev>::ones(&mut graph); + let c = a.matmul_axpby(b, o, 1., 1.); + let tensor = c.to_tensor().unwrap(); + let expected: [Vec<[f32; 2]>; 1] = [vec![[4.0, 4.0], [4.0, 4.0]]]; + assert_eq!(tensor.data().unwrap().to_vec(), expected); + } } }; } @@ -100,6 +123,29 @@ macro_rules! test_for_device_int { let tensor = res.to_tensor().unwrap(); assert_eq!(tensor.data().unwrap().to_vec(), vec![1, 2, 3, 4]); } + + #[test] + fn matmul() { + let mut graph = Graph::empty(); + let a = GraphTensor::, i32, $dev>::ones(&mut graph); + let b = GraphTensor::, i32, $dev>::ones(&mut graph); + let c = a.matmul(b); + let tensor = c.to_tensor().unwrap(); + let expected: [Vec<[i32; 2]>; 1] = [vec![[3, 3], [3, 3]]]; + assert_eq!(tensor.data().unwrap().to_vec(), expected); + } + + #[test] + fn matmul_axpby() { + let mut graph = Graph::empty(); + let a = GraphTensor::, i32, $dev>::ones(&mut graph); + let b = GraphTensor::, i32, $dev>::ones(&mut graph); + let o = GraphTensor::, i32, $dev>::ones(&mut graph); + let c = a.matmul_axpby(b, o, 1, 1); + let tensor = c.to_tensor().unwrap(); + let expected: [Vec<[i32; 2]>; 1] = [vec![[4, 4], [4, 4]]]; + assert_eq!(tensor.data().unwrap().to_vec(), expected); + } } }; } diff --git a/graph.png b/graph.png index 8a2850d..6f0d711 100644 Binary files a/graph.png and 
b/graph.png differ