Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions constensor-core/examples/matmul/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use constensor_core::{CompiledGraph, Cpu, DType, Graph, GraphTensor, R3};
use constensor_core::{BestDevice, CompiledGraph, DType, Graph, GraphTensor, R3};
use std::time::Instant;

fn bench<T: DType, const B: usize, const M: usize, const K: usize, const N: usize>(
Expand All @@ -7,22 +7,25 @@ fn bench<T: DType, const B: usize, const M: usize, const K: usize, const N: usiz
beta: T,
) {
// Number of times to run the matmul for averaging
let iterations = 1000;
let iterations = 1;
let mut total = std::time::Duration::new(0, 0);

let mut graph = Graph::empty();
let a = GraphTensor::<R3<B, M, K>, T, Cpu>::ones(&mut graph);
let b = GraphTensor::<R3<B, K, N>, T, Cpu>::ones(&mut graph);
let o = GraphTensor::<R3<B, M, N>, T, Cpu>::ones(&mut graph);
let a = GraphTensor::<R3<B, M, K>, T, BestDevice<0>>::fill(&mut graph, T::from_f64(1.));
// Strided matmuls work on all devices.
let b = GraphTensor::<R3<B, N, K>, T, BestDevice<0>>::fill(&mut graph, T::from_f64(2.)).t();
// let b = GraphTensor::<R3<B, K, N>, T, BestDevice<0>>::fill(&mut graph, T::from_f64(2.));
let o = GraphTensor::<R3<B, M, N>, T, BestDevice<0>>::fill(&mut graph, T::from_f64(3.));
let _c = a.matmul_axpby(b, o, alpha, beta);

graph.optimize();
let compiled: CompiledGraph<R3<B, M, N>, T, Cpu> = graph.compile().unwrap();
let compiled: CompiledGraph<R3<B, M, N>, T, BestDevice<0>> = graph.compile().unwrap();

for _ in 0..iterations {
let start = Instant::now();

let _tensor = std::hint::black_box(compiled.run().unwrap());
let tensor = std::hint::black_box(compiled.run().unwrap());
dbg!(tensor.data().unwrap());

total += start.elapsed();
}
Expand All @@ -33,10 +36,10 @@ fn bench<T: DType, const B: usize, const M: usize, const K: usize, const N: usiz

fn main() {
const B: usize = 1;
const M: usize = 128;
const N: usize = 128;
const K: usize = 128;
const M: usize = 2;
const N: usize = 2;
const K: usize = 2;

bench::<f32, B, M, K, N>("f32", 1.0, 1.0);
bench::<i32, B, M, K, N>("i32", 1, 1);
// bench::<i32, B, M, K, N>("i32", 1, 1);
}
67 changes: 58 additions & 9 deletions constensor-core/src/cpu_storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelI

use crate::device::Dev;
use crate::storage::Storage;
use crate::tensor::contiguous_strides;
use crate::Shape;
use crate::{
storage::{BackendDevice, BackendStorage},
Expand Down Expand Up @@ -44,11 +45,12 @@ impl BackendDevice for CpuDevice {
) -> Result<CompiledGraph<S, T, D>> {
// Build a dependency graph of tensor indices
let mut dep_graph = DiGraphMap::<usize, ()>::new();
for idx in 0..graph.len() {
dep_graph.add_node(idx);
for id in graph.iter().map(|node| node.id.get()) {
dep_graph.add_node(id);
}

for (idx, node) in graph.iter().enumerate() {
for node in graph.iter() {
let idx = node.id.get();
match &node.op {
Op::BinaryOp { l_id, r_id, .. } => {
dep_graph.add_edge(l_id.get(), idx, ());
Expand All @@ -62,9 +64,17 @@ impl BackendDevice for CpuDevice {
dep_graph.add_edge(b_id.get(), idx, ());
dep_graph.add_edge(c_id.get(), idx, ());
}
Op::MatMul { l_id, r_id, .. } => {
Op::MatMul {
l_id, r_id, o_id, ..
} => {
dep_graph.add_edge(l_id.get(), idx, ());
dep_graph.add_edge(r_id.get(), idx, ());
if let Some(o_id) = o_id {
dep_graph.add_edge(o_id.get(), idx, ());
}
}
Op::Permute { v_id } => {
dep_graph.add_edge(v_id.get(), idx, ());
}
// NoOp, Fill/Arange, Rand/Randn don’t create incoming edges
Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {}
Expand Down Expand Up @@ -102,6 +112,8 @@ impl BackendDevice for CpuDevice {
// Prepare storage for intermediate results
let mut results: Vec<Option<PooledBuffer<T>>> = Vec::with_capacity(graph.len());
results.resize_with(graph.len(), || None);
let mut results_strides: Vec<Option<Vec<usize>>> = Vec::with_capacity(graph.len());
results_strides.resize_with(graph.len(), || None);

let mut rng = rng();

Expand Down Expand Up @@ -224,27 +236,64 @@ impl BackendDevice for CpuDevice {
let m = shape[1];
let n = shape[2];

let mut out = if let Some(o_id) = o_id {
let (mut out, out_stride) = if let Some(o_id) = o_id {
if o_id.is_inplace() {
results[o_id.get()].take().unwrap()
let out_strides = results_strides[o_id.get()].as_ref().unwrap();
(results[o_id.get()].take().unwrap(), out_strides.clone())
} else {
let o_buf = results[o_id.get()].as_ref().unwrap();
PooledBuffer::new((*o_buf).clone(), pool.clone())
let out_strides = results_strides[o_id.get()].as_ref().unwrap();
(
PooledBuffer::new((*o_buf).clone(), pool.clone()),
out_strides.clone(),
)
}
} else {
PooledBuffer::new(pool.borrow_mut().get_buffer(m * n), pool.clone())
(
PooledBuffer::new(
pool.borrow_mut().get_buffer(b * m * n),
pool.clone(),
),
contiguous_strides(&[b, m, n]),
)
};

let a_buf = results[l_id.get()].as_ref().unwrap();
let b_buf = results[r_id.get()].as_ref().unwrap();

T::launch_gemm(a_buf, b_buf, b, m, n, *k, &mut out, *alpha, *beta);
let a_strides = results_strides[l_id.get()].as_ref().unwrap();
let b_strides = results_strides[r_id.get()].as_ref().unwrap();

T::launch_gemm(
a_buf,
a_strides,
b_buf,
b_strides,
b,
m,
n,
*k,
&mut out,
&out_stride,
*alpha,
*beta,
);

out
}
Op::NoOp => unreachable!("NoOp should not be evaluated."),
Op::Permute { v_id } => {
if v_id.is_inplace() {
results[v_id.get()].take().unwrap()
} else {
let buf = results[v_id.get()].as_ref().unwrap();
PooledBuffer::new((*buf).clone(), pool.clone())
}
}
};

results[idx] = Some(computed);
results_strides[idx] = Some(op.strides.clone());
}

// Extract final result
Expand Down
82 changes: 60 additions & 22 deletions constensor-core/src/cuda_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::sync::{
};
use std::{
borrow::Cow,
collections::{HashMap, HashSet, VecDeque},
collections::{HashMap, VecDeque},
fs,
hash::{DefaultHasher, Hash, Hasher},
marker::PhantomData,
Expand All @@ -25,6 +25,7 @@ use crate::{
cpu_storage::CpuStorage,
device::Dev,
storage::{BackendDevice, BackendStorage, Storage},
tensor::contiguous_strides,
CompiledGraph, DType, GraphNode, Op, Result, Shape,
};

Expand Down Expand Up @@ -211,6 +212,9 @@ pub enum CudaCompiledKernel<T: DType> {
r_id: usize,
/// Optional output tensor ID for axpby semantics
o_id: Option<usize>,
l_stride: Vec<usize>,
r_stride: Vec<usize>,
o_stride: Option<Vec<usize>>,
b: usize,
m: usize,
n: usize,
Expand Down Expand Up @@ -304,6 +308,10 @@ fn handle_node<T: DType>(
format!("( static_cast<T>(fma(static_cast<double>({a_name}), static_cast<double>({b_name}), static_cast<double>({c_name}))))")
}
Op::NoOp => unreachable!("no-op ops should never be reached."),
Op::Permute { v_id } => {
let name = handle_node(current_name, header, &graph[v_id.get()], graph);
format!("({})", name)
}
Op::MatMul { .. } | Op::Rand | Op::Randn { .. } => {
unreachable!("op should have its own split!")
}
Expand Down Expand Up @@ -468,7 +476,7 @@ impl BackendDevice for CudaDevice {
) -> Result<CompiledGraph<S, T, D>> {
// Build a dependency graph of tensor indices
let mut dep_graph = DiGraphMap::<usize, ()>::new();
for idx in 0..graph.len() {
for idx in graph.iter().map(|node| node.id.get()) {
dep_graph.add_node(idx);
}

Expand All @@ -486,9 +494,17 @@ impl BackendDevice for CudaDevice {
dep_graph.add_edge(b_id.get(), idx, ());
dep_graph.add_edge(c_id.get(), idx, ());
}
Op::MatMul { l_id, r_id, .. } => {
Op::MatMul {
l_id, r_id, o_id, ..
} => {
dep_graph.add_edge(l_id.get(), idx, ());
dep_graph.add_edge(r_id.get(), idx, ());
if let Some(o_id) = o_id {
dep_graph.add_edge(o_id.get(), idx, ());
}
}
Op::Permute { v_id } => {
dep_graph.add_edge(v_id.get(), idx, ());
}
// These don’t create incoming edges
Op::NoOp | Op::Fill { .. } | Op::Rand | Op::Randn { .. } | Op::Arange { .. } => {}
Expand All @@ -502,20 +518,6 @@ impl BackendDevice for CudaDevice {
let mut kernels = Vec::<CudaCompiledKernel<T>>::new();
let mut matmuls = Vec::<CudaCompiledKernel<T>>::new();
let mut splits: Vec<(Vec<usize>, Vec<usize>)> = Vec::new();
// Collect all matmul input node indices
let mut matmul_inputs = HashSet::new();
for &idx in &order {
if let Op::MatMul {
l_id, r_id, o_id, ..
} = &graph[idx].op
{
matmul_inputs.insert(l_id.get());
matmul_inputs.insert(r_id.get());
if let Some(o_id) = o_id {
matmul_inputs.insert(o_id.get());
}
}
}

for &idx in &order {
match &graph[idx].op {
Expand All @@ -529,8 +531,12 @@ impl BackendDevice for CudaDevice {
} => {
let l_shape = &graph[l_id.get()].shape;
let r_shape = &graph[r_id.get()].shape;
let l_stride = &graph[l_id.get()].strides;
let r_stride = &graph[r_id.get()].strides;
assert_eq!(l_shape.len(), 3);
assert_eq!(r_shape.len(), 3);
assert_eq!(l_stride.len(), 3);
assert_eq!(r_stride.len(), 3);
let (b, m, _k) = (l_shape[0], l_shape[1], l_shape[2]);
let n = r_shape[2];

Expand All @@ -542,6 +548,9 @@ impl BackendDevice for CudaDevice {
l_id: l_id.get(),
r_id: r_id.get(),
o_id: o_id.as_ref().map(|id| id.get()),
l_stride: l_stride.clone(),
r_stride: r_stride.clone(),
o_stride: o_id.as_ref().map(|id| graph[id.get()].strides.clone()),
b,
m,
n,
Expand Down Expand Up @@ -583,11 +592,32 @@ impl BackendDevice for CudaDevice {
}
_ => {
let shape_key = graph[idx].shape.clone();
// Group only when same shape and this op depends on the last split node
let should_group = if let Some((last_group, _)) = splits.last_mut() {
let last_idx = *last_group.last().unwrap();
let last_shape_key = graph[last_idx].shape.clone();
// Force all matmul inputs to have their own
last_shape_key == shape_key && !matmul_inputs.contains(&idx)
if graph[last_idx].shape == shape_key {
match &graph[idx].op {
Op::BinaryOp { l_id, r_id, .. } => {
l_id.get() == last_idx || r_id.get() == last_idx
}
Op::UnaryOp { v_id, .. } => v_id.get() == last_idx,
Op::FusedMulAdd { a_id, b_id, c_id } => {
a_id.get() == last_idx
|| b_id.get() == last_idx
|| c_id.get() == last_idx
}
Op::Permute { v_id } => v_id.get() == last_idx,
// Init ops always start new group
Op::NoOp
| Op::Fill { .. }
| Op::Arange { .. }
| Op::Rand
| Op::Randn { .. }
| Op::MatMul { .. } => false,
}
} else {
false
}
} else {
false
};
Expand Down Expand Up @@ -654,6 +684,9 @@ impl BackendDevice for CudaDevice {
l_id,
r_id,
o_id,
l_stride,
r_stride,
o_stride,
b,
m,
n,
Expand All @@ -672,7 +705,7 @@ impl BackendDevice for CudaDevice {
lhs.event.synchronize().w()?;
rhs.event.synchronize().w()?;

let elems = m * n;
let elems = b * m * n;
// prepare output buffer, copy initial if provided
let mut out = unsafe { stream.alloc::<T>(elems) }.w()?;
if let Some(o_idx) = o_id {
Expand All @@ -682,9 +715,14 @@ impl BackendDevice for CudaDevice {
self.stream().memcpy_dtod(&init.slice, &mut out).w()?;
}

let o_stride = o_stride
.clone()
.unwrap_or(contiguous_strides(&[*b, *m, *n]));

// Launch GEMM on the pooled stream
T::launch_gemm_cuda(
cublas, &lhs.slice, &rhs.slice, *b, *m, *n, *k, &mut out, *beta, *alpha,
cublas, &lhs.slice, &rhs.slice, l_stride, r_stride, *b, *m, *n, *k,
&mut out, &o_stride, *beta, *alpha,
)?;

// Record completion event for the MatMul result
Expand Down
Loading
Loading