Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,927 changes: 1,435 additions & 492 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ gemm = "0.18"
num_cpus = "1.16.0"
rand = "0.9.1"
rand_distr = "0.5.1"
cubecl = "0.6.0"
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ Experimental machine learning framework featuring a graph-based JIT compiler.
- Multi-device support (CPU, optional CUDA)
- Graph visualization (requires Graphviz)
- Zero-cost abstractions with idiomatic Rust API
- Broad GPU support:
- CUDA (NVIDIA): `--features cuda`
- HIP (AMD): `--features hip`
- Metal (Apple): `--features metal`
- WGPU (recommended for [all others](https://github.com/gfx-rs/wgpu?tab=readme-ov-file#supported-platforms)): `--features wgpu`

## Installation

Expand Down
13 changes: 5 additions & 8 deletions constensor-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,19 @@ gemm.workspace = true
num_cpus.workspace = true
rand.workspace = true
rand_distr.workspace = true
cubecl = { workspace = true, features = ["compilation-cache", "wgpu"] }

[features]
default = []
all = ["cuda", "half", "bfloat"]
cuda = ["cudarc"]
default = ["half", "bfloat"]
cuda = ["cudarc", "cubecl/cuda"]
hip = ["cubecl/hip"]
metal = ["cubecl/wgpu-msl"]
half = ["dep:half"]
bfloat = ["dep:half"]
slow_integral_fma_cuda = []

[[example]]
name = "hello_world"
required-features = []

[dev-dependencies]
criterion = "0.5"
candle-core = "0.8"

[[bench]]
name = "cpu_graph"
Expand Down
68 changes: 34 additions & 34 deletions constensor-core/benches/cpu_graph.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use candle_core::{Device, Tensor};
// use candle_core::{Device, Tensor};
use constensor_core::{Cpu, Graph, GraphTensor, R3};
use criterion::{criterion_group, criterion_main, Criterion};

Expand Down Expand Up @@ -44,46 +44,46 @@ fn bench_cpu_graph_matmul_256(c: &mut Criterion) {
});
}

fn bench_candle_matmul_64(c: &mut Criterion) {
const N: usize = 64;
c.bench_function("candle_matmul_64x64", |bencher| {
bencher.iter(|| {
let a = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
let b = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
let _ = a.matmul(&b).unwrap();
});
});
}
// fn bench_candle_matmul_64(c: &mut Criterion) {
// const N: usize = 64;
// let a = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
// let b = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
// c.bench_function("candle_matmul_64x64", |bencher| {
// bencher.iter(|| {
// let _ = a.matmul(&b).unwrap();
// });
// });
// }

fn bench_candle_matmul_128(c: &mut Criterion) {
const N: usize = 128;
c.bench_function("candle_matmul_128x128", |bencher| {
bencher.iter(|| {
let a = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
let b = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
let _ = a.matmul(&b).unwrap();
});
});
}
// fn bench_candle_matmul_128(c: &mut Criterion) {
// const N: usize = 128;
// let a = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
// let b = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
// c.bench_function("candle_matmul_128x128", |bencher| {
// bencher.iter(|| {
// let _ = a.matmul(&b).unwrap();
// });
// });
// }

fn bench_candle_matmul_256(c: &mut Criterion) {
const N: usize = 256;
c.bench_function("candle_matmul_256x256", |bencher| {
bencher.iter(|| {
let a = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
let b = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
let _ = a.matmul(&b).unwrap();
});
});
}
// fn bench_candle_matmul_256(c: &mut Criterion) {
// const N: usize = 256;
// let a = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
// let b = Tensor::rand(0f32, 1f32, &[1, N, N], &Device::Cpu).unwrap();
// c.bench_function("candle_matmul_256x256", |bencher| {
// bencher.iter(|| {
// let _ = a.matmul(&b).unwrap();
// });
// });
// }

criterion_group!(
benches,
bench_cpu_graph_matmul_64,
bench_cpu_graph_matmul_128,
bench_cpu_graph_matmul_256,
bench_candle_matmul_64,
bench_candle_matmul_128,
bench_candle_matmul_256
// bench_candle_matmul_64,
// bench_candle_matmul_128,
// bench_candle_matmul_256
);
criterion_main!(benches);
18 changes: 18 additions & 0 deletions constensor-core/examples/test/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use constensor_core::{Graph, GraphTensor, Tensor, Wgpu, R2};

/// Builds a tiny elementwise graph on the WGPU backend, renders it with
/// Graphviz, then compiles, runs, and verifies the result.
fn main() {
    // Shorthand for the shape used throughout this example.
    type Shape = R2<3, 4>;

    let mut graph: Graph<f32> = Graph::empty();

    // Three constant tensors; their elementwise sum is 1.0 + 2.0 + 3.0 = 6.0.
    let lhs = GraphTensor::<Shape, f32, Wgpu>::fill(&mut graph, 1.0);
    let mid = GraphTensor::<Shape, f32, Wgpu>::fill(&mut graph, 2.0);
    let rhs = GraphTensor::<Shape, f32, Wgpu>::fill(&mut graph, 3.0);
    let _sum = lhs + mid + rhs;

    // Emit a visualization of the graph (requires Graphviz).
    graph.visualize("graph.png").unwrap();

    // Compile the graph for the WGPU device and execute it.
    let compiled: constensor_core::CompiledGraph<Shape, f32, Wgpu> = graph.compile().unwrap();
    let tensor: Tensor<Shape, f32, Wgpu> = compiled.run().unwrap();

    // Every element of the 3x4 output should be 6.0.
    assert_eq!(tensor.data().unwrap().to_vec(), vec![vec![6.0; 4]; 3]);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use petgraph::algo::toposort;
use petgraph::graphmap::DiGraphMap;
use super::scheduler::topo_order;
use std::{borrow::Cow, marker::PhantomData};

use pool::{BufferPool, PooledBuffer};
Expand Down Expand Up @@ -31,7 +30,7 @@ impl<T: DType> BackendStorage<T> for CpuStorage<T> {
Ok(Cow::Borrowed(self))
}
fn cast<U: DType>(&self) -> Result<Storage<U>> {
let new = self.0.iter().map(|x| U::from_f64(x.to_f64()));
let new = self.0.iter().map(|x| U::from_f64(x.cast_f64()));
Ok(Storage::Cpu(CpuStorage(new.collect())))
}
}
Expand All @@ -43,47 +42,8 @@ impl BackendDevice for CpuDevice {
&self,
graph: Vec<GraphNode<T>>,
) -> Result<CompiledGraph<S, T, D>> {
// Build a dependency graph of tensor indices
let mut dep_graph = DiGraphMap::<usize, ()>::new();
for id in 0..graph.len() {
dep_graph.add_node(id);
}

for node in graph.iter() {
let idx = node.id.get();
match &node.op {
Op::BinaryOp { l_id, r_id, .. } => {
dep_graph.add_edge(l_id.get(), idx, ());
dep_graph.add_edge(r_id.get(), idx, ());
}
Op::UnaryOp { v_id, .. } => {
dep_graph.add_edge(v_id.get(), idx, ());
}
Op::FusedMulAdd { a_id, b_id, c_id } => {
dep_graph.add_edge(a_id.get(), idx, ());
dep_graph.add_edge(b_id.get(), idx, ());
dep_graph.add_edge(c_id.get(), idx, ());
}
Op::MatMul {
l_id, r_id, o_id, ..
} => {
dep_graph.add_edge(l_id.get(), idx, ());
dep_graph.add_edge(r_id.get(), idx, ());
if let Some(o_id) = o_id {
dep_graph.add_edge(o_id.get(), idx, ());
}
}
Op::Permute { v_id } => {
dep_graph.add_edge(v_id.get(), idx, ());
}
// NoOp, Fill/Arange, Rand/Randn don’t create incoming edges
Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {}
}
}

// Compute topological order
let order = toposort(&dep_graph, None).expect("Cycle detected in graph!");

// Compute topological order using shared scheduler
let order = topo_order(&graph);
Ok(CompiledGraph::Cpu {
order,
graph,
Expand Down Expand Up @@ -240,10 +200,10 @@ fn eval_node<T: DType + Send + Sync + 'static>(
}
Op::Arange { start, step, stop } => {
let mut buf = pool.lock().unwrap().get_empty_buffer(out_elem_count);
let mut x = start.to_f64();
while x < stop.to_f64() {
let mut x = start.cast_f64();
while x < stop.cast_f64() {
buf.push(T::from_f64(x));
x += step.to_f64();
x += step.cast_f64();
}
PooledBuffer::new(buf, pool.clone())
}
Expand All @@ -255,8 +215,8 @@ fn eval_node<T: DType + Send + Sync + 'static>(
PooledBuffer::new(buf, pool.clone())
}
Op::Randn { mean, std } => {
let mean_f = mean.to_f64();
let std_f = std.to_f64();
let mean_f = mean.cast_f64();
let std_f = std.cast_f64();
let normal = Normal::new(mean_f, std_f).unwrap();
let mut buf = pool.lock().unwrap().get_buffer(out_elem_count);
for elt in &mut buf {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<T: DType> BufferPool<T> {
self.pool
.iter()
.map(|b| b.capacity() * size_of::<T>())
.sum()
.sum::<usize>()
);

buf
Expand Down Expand Up @@ -161,7 +161,7 @@ impl<T: DType> BufferPool<T> {
self.pool
.iter()
.map(|b| b.capacity() * size_of::<T>())
.sum()
.sum::<usize>()
);

self.trim_excess();
Expand Down Expand Up @@ -195,7 +195,7 @@ impl<T: DType> BufferPool<T> {
self.pool
.iter()
.map(|b| b.capacity() * size_of::<T>())
.sum()
.sum::<usize>()
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::scheduler::topo_order;
use cudarc::{
cublas::CudaBlas,
driver::{
Expand All @@ -6,7 +7,6 @@ use cudarc::{
nvrtc::{CompileOptions, Ptx},
};
use error::WrapErr;
use petgraph::{algo::toposort, prelude::DiGraphMap};
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex, RwLock,
Expand Down Expand Up @@ -474,45 +474,8 @@ impl BackendDevice for CudaDevice {
&self,
graph: Vec<GraphNode<T>>,
) -> Result<CompiledGraph<S, T, D>> {
// Build a dependency graph of tensor indices
let mut dep_graph = DiGraphMap::<usize, ()>::new();
for idx in 0..graph.len() {
dep_graph.add_node(idx);
}

for (idx, node) in graph.iter().enumerate() {
match &node.op {
Op::BinaryOp { l_id, r_id, .. } => {
dep_graph.add_edge(l_id.get(), idx, ());
dep_graph.add_edge(r_id.get(), idx, ());
}
Op::UnaryOp { v_id, .. } => {
dep_graph.add_edge(v_id.get(), idx, ());
}
Op::FusedMulAdd { a_id, b_id, c_id } => {
dep_graph.add_edge(a_id.get(), idx, ());
dep_graph.add_edge(b_id.get(), idx, ());
dep_graph.add_edge(c_id.get(), idx, ());
}
Op::MatMul {
l_id, r_id, o_id, ..
} => {
dep_graph.add_edge(l_id.get(), idx, ());
dep_graph.add_edge(r_id.get(), idx, ());
if let Some(o_id) = o_id {
dep_graph.add_edge(o_id.get(), idx, ());
}
}
Op::Permute { v_id } => {
dep_graph.add_edge(v_id.get(), idx, ());
}
// These don’t create incoming edges
Op::NoOp | Op::Fill { .. } | Op::Rand | Op::Randn { .. } | Op::Arange { .. } => {}
}
}

// Compute topological order
let order = toposort(&dep_graph, None).expect("Cycle detected in graph!");
// Compute topological order using shared scheduler
let order = topo_order(&graph);

// New kernel and grouping logic with matmul input tracking
let mut kernels = Vec::<CudaCompiledKernel<T>>::new();
Expand Down
5 changes: 5 additions & 0 deletions constensor-core/src/backends/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//! Device backends. Each submodule implements graph compilation and
//! execution for one target; ordering logic is shared via `scheduler`.

pub mod cpu_backend;
// Only built when CUDA support is requested (`--features cuda`).
#[cfg(feature = "cuda")]
pub mod cuda_backend;
// Backend-agnostic helpers (topological ordering of graph nodes).
pub mod scheduler;
pub mod wgpu_backend;
50 changes: 50 additions & 0 deletions constensor-core/src/backends/scheduler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use petgraph::algo::toposort;
use petgraph::graphmap::DiGraphMap;

use crate::{DType, GraphNode, Op};

/// Compute a topological ordering of the computation graph nodes.
///
/// Edges run from each operand node to the node that consumes it, so the
/// returned order guarantees every node appears after all of its inputs.
///
/// # Panics
/// Panics if the graph contains a cycle.
pub fn topo_order<T: DType>(graph: &[GraphNode<T>]) -> Vec<usize> {
    // Pre-size the dependency graph: the node count is exact, and every op
    // contributes at most three incoming edges (FusedMulAdd / MatMul).
    let mut dep_graph = DiGraphMap::<usize, ()>::with_capacity(graph.len(), graph.len() * 3);
    for node in graph.iter() {
        dep_graph.add_node(node.id.get());
    }
    for node in graph.iter() {
        let dst = node.id.get();
        match &node.op {
            Op::BinaryOp { l_id, r_id, .. } => {
                dep_graph.add_edge(l_id.get(), dst, ());
                dep_graph.add_edge(r_id.get(), dst, ());
            }
            Op::UnaryOp { v_id, .. } => {
                dep_graph.add_edge(v_id.get(), dst, ());
            }
            Op::FusedMulAdd { a_id, b_id, c_id } => {
                dep_graph.add_edge(a_id.get(), dst, ());
                dep_graph.add_edge(b_id.get(), dst, ());
                dep_graph.add_edge(c_id.get(), dst, ());
            }
            Op::MatMul {
                l_id, r_id, o_id, ..
            } => {
                dep_graph.add_edge(l_id.get(), dst, ());
                dep_graph.add_edge(r_id.get(), dst, ());
                if let Some(o) = o_id {
                    dep_graph.add_edge(o.get(), dst, ());
                }
            }
            Op::Permute { v_id } => {
                dep_graph.add_edge(v_id.get(), dst, ());
            }
            // Source ops have no operands and thus no incoming edges. Listed
            // explicitly (rather than `_`) so that adding a new `Op` variant
            // forces this scheduler to be revisited at compile time.
            Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {}
        }
    }
    // Compute the topological order; a cycle here is a construction bug.
    toposort(&dep_graph, None).expect("Cycle detected in graph!")
}
Loading