diff --git a/README.md b/README.md
index 45032cc..465ab81 100644
--- a/README.md
+++ b/README.md
@@ -10,13 +10,15 @@ ML framework featuring compile time checks and accelerated by a JIT compiler.
-Constensor is a fast alternative to Candle which provides the following key features:
+Constensor is a fast ML framework that provides the following key features:
+
- **Compile time shape, dtype, and device checking**: Develop quickly and handle common errors
- **Opt-in half precision support**: Run on any GPU
-- **Elementwise JIT kernel fusion**: Accelerate CUDA kernels automatically by fusing binary and unary operations
- - Fuse binary operations into one kernel
- - Use device specific operations such as [`fma`](https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__DOUBLE.html#group__CUDA__MATH__DOUBLE_1gff2117f6f3c4ff8a2aa4ce48a0ff2070) to accelerate.
-- **Automatic inplacing**: Avoid duplicate allocations
+- **Advanced AI compiler features:**
+ - Elementwise JIT kernel fusion
+ - Automatic inplacing
+ - Constant folding
+ - Dead code removal
```rust
diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs
index 152fc75..6a3ba25 100644
--- a/constensor-core/src/graph.rs
+++ b/constensor-core/src/graph.rs
@@ -207,6 +207,50 @@ impl Graph {
Ok(())
}
+ /// Optimize by performing constant folding:
+ /// - Fold BinaryOp and UnaryOp when all operands are constant Fill ops.
+ fn optimize_const(&mut self) {
+ // Clone current ops for inspection
+ let ops = self.data.read().unwrap().clone();
+ let mut new_ops = ops.clone();
+ for (i, node) in ops.iter().enumerate() {
+ match &node.op {
+ Op::BinaryOp {
+ l_id,
+ r_id,
+ operator,
+ } => {
+ let l_idx = l_id.get();
+ let r_idx = r_id.get();
+ // both operands are constant fills
+ if let Op::Fill { v: v1 } = &new_ops[l_idx].op {
+ if let Op::Fill { v: v2 } = &new_ops[r_idx].op {
+ let v = operator.as_closure()(*v1, *v2);
+ new_ops[i] = GraphNode {
+ op: Op::Fill { v },
+ shape: node.shape.clone(),
+ };
+ }
+ }
+ }
+ Op::UnaryOp { v_id, operator } => {
+ let idx = v_id.get();
+ // operand is a constant fill
+ if let Op::Fill { v: v0 } = &new_ops[idx].op {
+ let v = operator.to_closure()(*v0);
+ new_ops[i] = GraphNode {
+ op: Op::Fill { v },
+ shape: node.shape.clone(),
+ };
+ }
+ }
+ _ => {}
+ }
+ }
+ // Commit folded constants
+ *self.data.write().unwrap() = new_ops;
+ }
+
/// Optimize by looking for mul-add pairs, convert to FMA
fn optimize_fma(&mut self) {
let ops = self.data.write().unwrap().clone();
@@ -432,15 +476,118 @@ impl Graph {
*self.data.write().unwrap() = new_ops;
}
+ /// Remove nodes whose outputs are never used, except the final output node.
+ fn optimize_dead_code(&mut self) {
+ // Clone current ops
+ let old_ops = self.data.read().unwrap().clone();
+ let n = old_ops.len();
+ // Mark reachable nodes: start from final output
+ let mut keep = vec![false; n];
+ if n > 0 {
+ keep[n - 1] = true;
+ }
+ // Propagate reachability backwards
+ for i in (0..n).rev() {
+ if keep[i] {
+ match &old_ops[i].op {
+ Op::BinaryOp { l_id, r_id, .. } => {
+ keep[l_id.get()] = true;
+ keep[r_id.get()] = true;
+ }
+ Op::UnaryOp { v_id, .. } => {
+ keep[v_id.get()] = true;
+ }
+ Op::FusedMulAdd {
+ a_id, b_id, c_id, ..
+ } => {
+ keep[a_id.get()] = true;
+ keep[b_id.get()] = true;
+ keep[c_id.get()] = true;
+ }
+ Op::MatMul {
+ l_id, r_id, o_id, ..
+ } => {
+ keep[l_id.get()] = true;
+ keep[r_id.get()] = true;
+ if let Some(o_id) = o_id {
+ keep[o_id.get()] = true;
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ // Build new ops and map old indices to new indices
+ let mut index_map = std::collections::HashMap::new();
+ let mut new_ops = Vec::new();
+ for (old_idx, node) in old_ops.into_iter().enumerate() {
+ if keep[old_idx] {
+ let new_idx = new_ops.len();
+ index_map.insert(old_idx, new_idx);
+ new_ops.push(node);
+ }
+ }
+ // Update tensor IDs in remaining ops
+ for node in new_ops.iter_mut() {
+ match &mut node.op {
+ Op::BinaryOp { l_id, r_id, .. } => {
+ let old_l = l_id.get();
+ let old_r = r_id.get();
+ l_id.set(*index_map.get(&old_l).unwrap());
+ r_id.set(*index_map.get(&old_r).unwrap());
+ }
+ Op::UnaryOp { v_id, .. } => {
+ let old_v = v_id.get();
+ v_id.set(*index_map.get(&old_v).unwrap());
+ }
+ Op::FusedMulAdd {
+ a_id, b_id, c_id, ..
+ } => {
+ let old_a = a_id.get();
+ let old_b = b_id.get();
+ let old_c = c_id.get();
+ a_id.set(*index_map.get(&old_a).unwrap());
+ b_id.set(*index_map.get(&old_b).unwrap());
+ c_id.set(*index_map.get(&old_c).unwrap());
+ }
+ Op::MatMul {
+ l_id, r_id, o_id, ..
+ } => {
+ let old_l = l_id.get();
+ let old_r = r_id.get();
+ l_id.set(*index_map.get(&old_l).unwrap());
+ r_id.set(*index_map.get(&old_r).unwrap());
+ if let Some(o_id) = o_id {
+ let old_o = o_id.get();
+ o_id.set(*index_map.get(&old_o).unwrap());
+ }
+ }
+ _ => {}
+ }
+ }
+ // Commit pruned graph
+ *self.data.write().unwrap() = new_ops;
+ }
+
/// Optimize this graph.
///
- /// Apply the following optimizations
- /// - Fuse mul,add
+ /// Apply the following optimizations:
+ /// - Constant folding of elementwise fills
+ /// - Fuse mul-add into FMA
+ /// - Inplace binary operations when safe
+ /// - Inplace fused multiply-add when safe
+ /// - Inplace matrix-multiplication when safe
+ /// - Dead code removal
pub fn optimize(&mut self) {
+ // Constant folding first
+ self.optimize_const();
+ // Fuse mul-add into FMA
self.optimize_fma();
self.optimize_inplace_bin();
self.optimize_inplace_fma();
self.optimize_inplace_matmul();
+ // Remove dead code
+ self.optimize_dead_code();
}
pub fn compile(self) -> Result> {
diff --git a/graph.png b/graph.png
index 6f0d711..584fddd 100644
Binary files a/graph.png and b/graph.png differ