diff --git a/README.md b/README.md index 45032cc..465ab81 100644 --- a/README.md +++ b/README.md @@ -10,13 +10,15 @@ ML framework featuring compile time checks and accelerated by a JIT compiler.

-Constensor is a fast alternative to Candle which provides the following key features: +Constensor is a fast ML framework which provides the following key features: + - **Compile time shape, dtype, and device checking**: Develop quickly and handle common errors - **Opt-in half precision support**: Run on any GPU -- **Elementwise JIT kernel fusion**: Accelerate CUDA kernels automatically by fusing binary and unary operations - - Fuse binary operations into one kernel - - Use device specific operations such as [`fma`](https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__DOUBLE.html#group__CUDA__MATH__DOUBLE_1gff2117f6f3c4ff8a2aa4ce48a0ff2070) to accelerate. -- **Automatic inplacing**: Avoid duplicate allocations +- **Advanced AI compiler features**: + - Elementwise JIT kernel fusion + - Automatic inplacing + - Constant folding + - Dead code removal ```rust diff --git a/constensor-core/src/graph.rs b/constensor-core/src/graph.rs index 152fc75..6a3ba25 100644 --- a/constensor-core/src/graph.rs +++ b/constensor-core/src/graph.rs @@ -207,6 +207,50 @@ impl Graph { Ok(()) } + /// Optimize by performing constant folding: + /// - Fold BinaryOp and UnaryOp when all operands are constant Fill ops. 
+ fn optimize_const(&mut self) { + // Clone current ops for inspection + let ops = self.data.read().unwrap().clone(); + let mut new_ops = ops.clone(); + for (i, node) in ops.iter().enumerate() { + match &node.op { + Op::BinaryOp { + l_id, + r_id, + operator, + } => { + let l_idx = l_id.get(); + let r_idx = r_id.get(); + // both operands are constant fills + if let Op::Fill { v: v1 } = &new_ops[l_idx].op { + if let Op::Fill { v: v2 } = &new_ops[r_idx].op { + let v = operator.as_closure()(*v1, *v2); + new_ops[i] = GraphNode { + op: Op::Fill { v }, + shape: node.shape.clone(), + }; + } + } + } + Op::UnaryOp { v_id, operator } => { + let idx = v_id.get(); + // operand is a constant fill + if let Op::Fill { v: v0 } = &new_ops[idx].op { + let v = operator.to_closure()(*v0); + new_ops[i] = GraphNode { + op: Op::Fill { v }, + shape: node.shape.clone(), + }; + } + } + _ => {} + } + } + // Commit folded constants + *self.data.write().unwrap() = new_ops; + } + /// Optimize by looking for mul-add pairs, convert to FMA fn optimize_fma(&mut self) { let ops = self.data.write().unwrap().clone(); @@ -432,15 +476,118 @@ impl Graph { *self.data.write().unwrap() = new_ops; } + /// Remove nodes whose outputs are never used, except the final output node. + fn optimize_dead_code(&mut self) { + // Clone current ops + let old_ops = self.data.read().unwrap().clone(); + let n = old_ops.len(); + // Mark reachable nodes: start from final output + let mut keep = vec![false; n]; + if n > 0 { + keep[n - 1] = true; + } + // Propagate reachability backwards + for i in (0..n).rev() { + if keep[i] { + match &old_ops[i].op { + Op::BinaryOp { l_id, r_id, .. } => { + keep[l_id.get()] = true; + keep[r_id.get()] = true; + } + Op::UnaryOp { v_id, .. } => { + keep[v_id.get()] = true; + } + Op::FusedMulAdd { + a_id, b_id, c_id, .. + } => { + keep[a_id.get()] = true; + keep[b_id.get()] = true; + keep[c_id.get()] = true; + } + Op::MatMul { + l_id, r_id, o_id, .. 
+ } => { + keep[l_id.get()] = true; + keep[r_id.get()] = true; + if let Some(o_id) = o_id { + keep[o_id.get()] = true; + } + } + _ => {} + } + } + } + // Build new ops and map old indices to new indices + let mut index_map = std::collections::HashMap::new(); + let mut new_ops = Vec::new(); + for (old_idx, node) in old_ops.into_iter().enumerate() { + if keep[old_idx] { + let new_idx = new_ops.len(); + index_map.insert(old_idx, new_idx); + new_ops.push(node); + } + } + // Update tensor IDs in remaining ops + for node in new_ops.iter_mut() { + match &mut node.op { + Op::BinaryOp { l_id, r_id, .. } => { + let old_l = l_id.get(); + let old_r = r_id.get(); + l_id.set(*index_map.get(&old_l).unwrap()); + r_id.set(*index_map.get(&old_r).unwrap()); + } + Op::UnaryOp { v_id, .. } => { + let old_v = v_id.get(); + v_id.set(*index_map.get(&old_v).unwrap()); + } + Op::FusedMulAdd { + a_id, b_id, c_id, .. + } => { + let old_a = a_id.get(); + let old_b = b_id.get(); + let old_c = c_id.get(); + a_id.set(*index_map.get(&old_a).unwrap()); + b_id.set(*index_map.get(&old_b).unwrap()); + c_id.set(*index_map.get(&old_c).unwrap()); + } + Op::MatMul { + l_id, r_id, o_id, .. + } => { + let old_l = l_id.get(); + let old_r = r_id.get(); + l_id.set(*index_map.get(&old_l).unwrap()); + r_id.set(*index_map.get(&old_r).unwrap()); + if let Some(o_id) = o_id { + let old_o = o_id.get(); + o_id.set(*index_map.get(&old_o).unwrap()); + } + } + _ => {} + } + } + // Commit pruned graph + *self.data.write().unwrap() = new_ops; + } + /// Optimize this graph. 
/// - /// Apply the following optimizations - /// - Fuse mul,add + /// Apply the following optimizations: + /// - Constant folding of elementwise fills + /// - Fuse mul-add into FMA + /// - Inplace binary operations when safe + /// - Inplace fused multiply-add when safe + /// - Inplace matrix-multiplication when safe + /// - Dead code removal pub fn optimize(&mut self) { + // Constant folding first + self.optimize_const(); + // Fuse mul-add into FMA self.optimize_fma(); self.optimize_inplace_bin(); self.optimize_inplace_fma(); self.optimize_inplace_matmul(); + // Remove dead code + self.optimize_dead_code(); } pub fn compile(self) -> Result> { diff --git a/graph.png b/graph.png index 6f0d711..584fddd 100644 Binary files a/graph.png and b/graph.png differ