Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ ML framework featuring compile time checks and accelerated by a JIT compiler.

</p>

Constensor is a fast alternative to Candle which provides the following key features:
Constensor is a fast ML framework which provides the following key features:

- **Compile time shape, dtype, and device checking**: Develop quickly and handle common errors
- **Opt-in half precision support**: Run on any GPU
- **Elementwise JIT kernel fusion**: Accelerate CUDA kernels automatically by fusing binary and unary operations
- Fuse binary operations into one kernel
- Use device-specific operations such as [`fma`](https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__DOUBLE.html#group__CUDA__MATH__DOUBLE_1gff2117f6f3c4ff8a2aa4ce48a0ff2070) to accelerate computation.
- **Automatic inplacing**: Avoid duplicate allocations
- **Advanced AI compiler features:**
- Elementwise JIT kernel fusion
- Automatic inplacing
- Constant folding
- Dead code removal


```rust
Expand Down
151 changes: 149 additions & 2 deletions constensor-core/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,50 @@ impl<T: DType> Graph<T> {
Ok(())
}

/// Optimize by performing constant folding:
/// - Fold BinaryOp and UnaryOp when all operands are constant Fill ops.
///
/// Operand ids always refer to earlier nodes, so a single forward pass over
/// one mutable copy of the op list cascades folds through chains of constant
/// expressions (each fold is visible to the nodes that come after it).
fn optimize_const(&mut self) {
    // One clone of the current ops is enough: folding node `i` only reads
    // nodes with smaller indices, which are already in their final form.
    let mut new_ops = self.data.read().unwrap().clone();
    for i in 0..new_ops.len() {
        // Read-only step: compute the folded constant for node `i`, if all
        // of its operands are constant `Fill`s. Keeping this separate from
        // the write below avoids overlapping borrows of `new_ops`.
        let folded = match &new_ops[i].op {
            Op::BinaryOp {
                l_id,
                r_id,
                operator,
            } => match (&new_ops[l_id.get()].op, &new_ops[r_id.get()].op) {
                // Both operands are constant fills.
                (Op::Fill { v: v1 }, Op::Fill { v: v2 }) => {
                    Some(operator.as_closure()(*v1, *v2))
                }
                _ => None,
            },
            Op::UnaryOp { v_id, operator } => {
                // Operand is a constant fill.
                if let Op::Fill { v: v0 } = &new_ops[v_id.get()].op {
                    Some(operator.to_closure()(*v0))
                } else {
                    None
                }
            }
            _ => None,
        };
        if let Some(v) = folded {
            // Replace the node with a constant fill of the same shape.
            let shape = new_ops[i].shape.clone();
            new_ops[i] = GraphNode {
                op: Op::Fill { v },
                shape,
            };
        }
    }
    // Commit folded constants
    *self.data.write().unwrap() = new_ops;
}

/// Optimize by looking for mul-add pairs, convert to FMA
fn optimize_fma(&mut self) {
let ops = self.data.write().unwrap().clone();
Expand Down Expand Up @@ -432,15 +476,118 @@ impl<T: DType> Graph<T> {
*self.data.write().unwrap() = new_ops;
}

/// Remove nodes whose outputs are never used, except the final output node.
///
/// Reachability starts at the last node (the graph's output) and is
/// propagated backwards; because every operand id refers to an earlier
/// node, a single reverse pass suffices.
fn optimize_dead_code(&mut self) {
    // Snapshot the current ops.
    let old_ops = self.data.read().unwrap().clone();
    let n = old_ops.len();
    // Mark reachable nodes: start from the final output.
    let mut keep = vec![false; n];
    if n > 0 {
        keep[n - 1] = true;
    }
    // Propagate reachability backwards through each kept node's operands.
    // NOTE(review): variants falling into `_ => {}` are assumed to carry no
    // operand ids — confirm this covers every id-bearing `Op` variant,
    // including any inplace forms produced by the earlier inplace passes.
    for i in (0..n).rev() {
        if !keep[i] {
            continue;
        }
        match &old_ops[i].op {
            Op::BinaryOp { l_id, r_id, .. } => {
                keep[l_id.get()] = true;
                keep[r_id.get()] = true;
            }
            Op::UnaryOp { v_id, .. } => {
                keep[v_id.get()] = true;
            }
            Op::FusedMulAdd {
                a_id, b_id, c_id, ..
            } => {
                keep[a_id.get()] = true;
                keep[b_id.get()] = true;
                keep[c_id.get()] = true;
            }
            Op::MatMul {
                l_id, r_id, o_id, ..
            } => {
                keep[l_id.get()] = true;
                keep[r_id.get()] = true;
                if let Some(o_id) = o_id {
                    keep[o_id.get()] = true;
                }
            }
            _ => {}
        }
    }
    // Compact the graph while recording old-index -> new-index. Indices
    // are dense, so a Vec is cheaper and simpler than a HashMap here.
    let mut index_map: Vec<Option<usize>> = vec![None; n];
    let mut new_ops = Vec::with_capacity(n);
    for (old_idx, node) in old_ops.into_iter().enumerate() {
        if keep[old_idx] {
            index_map[old_idx] = Some(new_ops.len());
            new_ops.push(node);
        }
    }
    // Every operand of a kept node was itself marked kept above, so this
    // lookup cannot fail.
    let remap = |idx: usize| index_map[idx].expect("operand of a kept node must be kept");
    // Rewrite operand ids of the surviving nodes.
    for node in new_ops.iter_mut() {
        match &mut node.op {
            Op::BinaryOp { l_id, r_id, .. } => {
                l_id.set(remap(l_id.get()));
                r_id.set(remap(r_id.get()));
            }
            Op::UnaryOp { v_id, .. } => {
                v_id.set(remap(v_id.get()));
            }
            Op::FusedMulAdd {
                a_id, b_id, c_id, ..
            } => {
                a_id.set(remap(a_id.get()));
                b_id.set(remap(b_id.get()));
                c_id.set(remap(c_id.get()));
            }
            Op::MatMul {
                l_id, r_id, o_id, ..
            } => {
                l_id.set(remap(l_id.get()));
                r_id.set(remap(r_id.get()));
                if let Some(o_id) = o_id {
                    o_id.set(remap(o_id.get()));
                }
            }
            _ => {}
        }
    }
    // Commit pruned graph
    *self.data.write().unwrap() = new_ops;
}

/// Optimize this graph in place.
///
/// Apply the following optimizations:
/// - Constant folding of elementwise fills
/// - Fuse mul-add into FMA
/// - Inplace binary operations when safe
/// - Inplace fused multiply-add when safe
/// - Inplace matrix-multiplication when safe
/// - Dead code removal
///
/// Pass order matters: constant folding runs first so the later passes see
/// the simplified graph, and dead code removal runs last so that nodes
/// orphaned by the earlier rewrites are also pruned.
pub fn optimize(&mut self) {
    // Constant folding first
    self.optimize_const();
    // Fuse mul-add into FMA
    self.optimize_fma();
    self.optimize_inplace_bin();
    self.optimize_inplace_fma();
    self.optimize_inplace_matmul();
    // Remove dead code
    self.optimize_dead_code();
}

pub fn compile<S: Shape, D: Dev>(self) -> Result<CompiledGraph<S, T, D>> {
Expand Down
Binary file modified graph.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading