diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 3f3041978..793256d67 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -3,7 +3,7 @@ //! An (example) use of the [dataflow analysis framework](super::dataflow). pub mod value_handle; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ @@ -13,7 +13,7 @@ use hugr_core::{ }, ops::{ constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + Value, }, types::{EdgeKind, TypeArg}, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, @@ -21,10 +21,12 @@ use hugr_core::{ use value_handle::ValueHandle; use crate::dataflow::{ - partial_from_const, AbstractValue, AnalysisResults, ConstLoader, ConstLocation, DFContext, - Machine, PartialValue, TailLoopTermination, + partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue, + TailLoopTermination, }; +use crate::dead_code::PreserveNode; use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{find_main, DeadCodeElimPass}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. @@ -88,22 +90,15 @@ impl ConstantFoldPass { }); let results = Machine::new(&hugr).run(ConstFoldContext(hugr), inputs); - let keep_nodes = self.find_needed_nodes(&results); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); - let remove_nodes = hugr + let wires_to_break = hugr .nodes() - .filter(|n| !keep_nodes.contains(n)) - .collect::>(); - let wires_to_break = keep_nodes - .into_iter() .flat_map(|n| hugr.node_inputs(n).map(move |ip| (n, ip))) .filter(|(n, ip)| { *n != hugr.root() && matches!(hugr.get_optype(*n).port_kind(*ip), Some(EdgeKind::Value(_))) }) - // Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping. 
- // This would insert fewer constants, but potentially expose less parallelism. .filter_map(|(n, ip)| { let (src, outp) = hugr.single_linked_output(n, ip).unwrap(); // Avoid breaking edges from existing LoadConstant (we'd only add another) @@ -118,20 +113,42 @@ impl ConstantFoldPass { )) }) .collect::>(); + // Sadly the results immutably borrow the hugr, so we must extract everything we need before mutation + let terminating_tail_loops = hugr + .nodes() + .filter(|n| { + results.tail_loop_terminates(*n) == Some(TailLoopTermination::NeverContinues) + }) + .collect::>(); - for (n, import, v) in wires_to_break { + for (n, inport, v) in wires_to_break { let parent = hugr.get_parent(n).unwrap(); let datatype = v.get_type(); // We could try hash-consing identical Consts, but not ATM let cst = hugr.add_node_with_parent(parent, Const::new(v)); let lcst = hugr.add_node_with_parent(parent, LoadConstant { datatype }); hugr.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0)); - hugr.disconnect(n, import); - hugr.connect(lcst, OutgoingPort::from(0), n, import); - } - for n in remove_nodes { - hugr.remove_node(n); + hugr.disconnect(n, inport); + hugr.connect(lcst, OutgoingPort::from(0), n, inport); } + // Dataflow analysis applies our inputs to the 'main' function if this is a Module, so do the same here + DeadCodeElimPass::default() + .with_entry_points(hugr.get_optype(hugr.root()).is_module().then( + // No main => remove everything, so not much use + || find_main(hugr).unwrap(), + )) + .set_preserve_callback(if self.allow_increase_termination { + Arc::new(|_, _| PreserveNode::CanRemove) + } else { + Arc::new(move |_, n| { + if terminating_tail_loops.contains(&n) { + PreserveNode::RemoveIfAllChildrenCanBeRemoved + } else { + PreserveNode::UseDefault + } + }) + }) + .run(hugr)?; Ok(()) } @@ -140,94 +157,6 @@ impl ConstantFoldPass { self.validation .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) } - - fn find_needed_nodes( - &self, - results: 
&AnalysisResults, - ) -> HashSet { - let mut needed = HashSet::new(); - let h = results.hugr(); - let mut q = VecDeque::from_iter([h.root()]); - while let Some(n) = q.pop_front() { - if !needed.insert(n) { - continue; - }; - if h.get_optype(n).is_module() { - for ch in h.children(n) { - match h.get_optype(ch) { - OpType::AliasDecl(_) | OpType::AliasDefn(_) => { - // Use of these is done via names, rather than following edges. - // We could track these as well but for now be conservative. - q.push_back(ch); - } - OpType::FuncDefn(f) if f.name == "main" => { - // Dataflow analysis will have applied any inputs the 'main' function, so assume reachable. - q.push_back(ch); - } - _ => (), - } - } - } else if h.get_optype(n).is_cfg() { - for bb in h.children(n) { - //if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates - q.push_back(bb); - } - } else if let Some(inout) = h.get_io(n) { - // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges. - q.extend(inout); // Input also necessary for legality even if unreachable - - if !self.allow_increase_termination { - // Also add on anything that might not terminate (even if results not required - - // if its results are required we'll add it by following dataflow, below.) 
- for ch in h.children(n) { - if might_diverge(results, ch) { - q.push_back(ch); - } - } - } - } - // Also follow dataflow demand - for (src, op) in h.all_linked_outputs(n) { - let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() { - EdgeKind::Value(_) => { - h.get_optype(src).is_load_constant() - || results - .try_read_wire_concrete::(Wire::new(src, op)) - .is_err() - } - EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true, - EdgeKind::ControlFlow => false, // we always include all children of a CFG above - _ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst - }; - if needs_predecessor { - q.push_back(src); - } - } - } - needed - } -} - -// "Diverge" aka "never-terminate" -// TODO would be more efficient to compute this bottom-up and cache (dynamic programming) -fn might_diverge(results: &AnalysisResults, n: Node) -> bool { - let op = results.hugr().get_optype(n); - if op.is_cfg() { - // TODO if the CFG has no cycles (that are possible given predicates) - // then we could say it definitely terminates (i.e. return false) - true - } else if op.is_tail_loop() - && results.tail_loop_terminates(n).unwrap() != TailLoopTermination::NeverContinues - { - // If we can even figure out the number of iterations is bounded that would allow returning false. - true - } else { - // Node does not introduce non-termination, but still non-terminates if any of its children does - results - .hugr() - .children(n) - .any(|ch| might_diverge(results, ch)) - } } /// Exhaustively apply constant folding to a HUGR. 
diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 172d87c26..727a26ec1 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -10,6 +10,8 @@ use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use crate::find_main; + use super::value_row::ValueRow; use super::{ partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, @@ -83,12 +85,7 @@ impl Machine { // we must find the corresponding Input node. let input_node_parent = match self.0.get_optype(root) { OpType::Module(_) => { - let main = self.0.children(root).find(|n| { - self.0 - .get_optype(*n) - .as_func_defn() - .is_some_and(|f| f.name == "main") - }); + let main = find_main(&self.0); if main.is_none() && in_values.next().is_some() { panic!("Cannot give inputs to module with no 'main'"); } diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs new file mode 100644 index 000000000..039154d45 --- /dev/null +++ b/hugr-passes/src/dead_code.rs @@ -0,0 +1,311 @@ +//! Pass for removing dead code, i.e. 
that computes values that are then discarded + +use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::fmt::{Debug, Formatter}; +use std::{ + collections::{HashSet, VecDeque}, + sync::Arc, +}; + +use crate::validation::{ValidatePassError, ValidationLevel}; + +/// Configuration for Dead Code Elimination pass +#[derive(Clone, Default)] +pub struct DeadCodeElimPass { + entry_points: Vec, + preserve_callback: Option>, + validation: ValidationLevel, +} + +impl Debug for DeadCodeElimPass { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + // Use "derive Debug" by defining an identical struct without the unprintable fields + + #[allow(unused)] // Rust ignores the derive-Debug in figuring out what's used + #[derive(Debug)] + struct DCEDebug<'a> { + entry_points: &'a Vec, + validation: ValidationLevel, + } + + Debug::fmt( + &DCEDebug { + entry_points: &self.entry_points, + validation: self.validation, + }, + f, + ) + } +} + +pub type PreserveCallback = dyn Fn(&Hugr, Node) -> PreserveNode; + +/// Signal that a node must be preserved even when its result is not used +pub enum PreserveNode { + /// The node must be kept (nodes inside it may be removed) + MustKeep, + /// The node can be removed, even if nodes inside it must be kept - the descendants' + /// [PreserveNode] will be ignored and they will be removed too, so use with care. 
+ CanRemove, + /// The default is that [Cfg], [Call] and [TailLoop] nodes must always be kept, + /// but otherwise like [Self::RemoveIfAllChildrenCanBeRemoved] + /// + /// [Call]: hugr_core::ops::Call + /// [CFG]: hugr_core::ops::CFG + /// [TailLoop]: hugr_core::ops::TailLoop + UseDefault, + /// The node must be kept if (and only if) any of its descendants must be kept + RemoveIfAllChildrenCanBeRemoved, +} + +impl DeadCodeElimPass { + /// Sets the validation level used before and after the pass is run + #[allow(unused)] + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self + } + + /// Allows setting a callback that determines whether a node must be preserved + /// (even when its result is not used) + pub fn set_preserve_callback(mut self, cb: Arc) -> Self { + self.preserve_callback = Some(cb); + self + } + + /// Mark some nodes as entry-points to the Hugr. + /// The root node is assumed to be an entry point; + /// for Module roots the client will want to mark some of the FuncDefn children + /// as entry-points too. 
+ pub fn with_entry_points(mut self, entry_points: impl IntoIterator) -> Self { + self.entry_points.extend(entry_points); + self + } + + fn find_needed_nodes(&self, h: impl HugrView) -> HashSet { + let mut needed = HashSet::new(); + let mut q = VecDeque::from_iter(self.entry_points.iter().cloned()); + q.push_front(h.root()); + while let Some(n) = q.pop_front() { + if !needed.insert(n) { + continue; + }; + for ch in h.children(n) { + if self.must_preserve(&h, ch) + || matches!( + h.get_optype(ch), + OpType::Case(_) // Include all Cases in Conditionals + | OpType::DataflowBlock(_) // and all Basic Blocks in CFGs + | OpType::ExitBlock(_) + | OpType::AliasDecl(_) // and all Aliases (we do not track their uses in types) + | OpType::AliasDefn(_) + | OpType::Input(_) // Also Dataflow input/output, these are necessary for legality + | OpType::Output(_) // Do not include FuncDecl / FuncDefn / Const unless reachable by static edges + // (from Call/LoadConst/LoadFunction): + ) + { + q.push_back(ch); + } + } + // Finally, follow dataflow demand (including e.g. edges from Call to FuncDefn) + for src in h.input_neighbours(n) { + // Following ControlFlow edges backwards is harmless, we've already assumed all + // BBs are reachable above. 
q.push_back(src); + } + } + needed + } + + pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + self.validation + .run_validated_pass(hugr, |h, _| self.run_no_validate(h)) + } + + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::<Vec<_>>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } + + // "Diverge" aka "never-terminate" + // TODO would be more efficient to compute this bottom-up and cache (dynamic programming) + fn must_preserve(&self, h: &impl HugrView, n: Node) -> bool { + match self + .preserve_callback + .as_ref() + .map_or(PreserveNode::UseDefault, |f| f(h.base_hugr(), n)) + { + PreserveNode::MustKeep => return true, + PreserveNode::CanRemove => return false, + PreserveNode::UseDefault => { + match h.get_optype(n) { + OpType::CFG(_) => { + // TODO if the CFG has no cycles (that are possible given predicates) + // then we could say it definitely terminates (i.e. return false) + return true; + } + OpType::TailLoop(_) => { + // If the TailLoop never continues, clearly it does terminate, but we haven't got + // dataflow results to tell us that. Instead rely on an earlier pass having rewritten + // such a TailLoop into a non-loop. + // Even just an upper-bound on the number of iterations would allow returning false. + return true; + } + OpType::Call(_) => { + // We could scan the target FuncDefn, but that might contain calls to itself, so we'd need + // a "seen" set...instead just rely on calls being inlined if we want to remove them. 
+ return true; + } + _ => (), // fall through to check children + } + } + PreserveNode::RemoveIfAllChildrenCanBeRemoved => (), // fall through to check children + } + + // Node does not introduce non-termination, but still non-terminates if any of its children does + h.children(n).any(|ch| self.must_preserve(h, ch)) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use hugr_core::builder::{CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder}; + use hugr_core::extension::prelude::{usize_t, ConstUsize, PRELUDE_ID}; + use hugr_core::ops::handle::NodeHandle; + use hugr_core::ops::{OpTag, OpTrait}; + use hugr_core::types::Signature; + use hugr_core::HugrView; + use hugr_core::{ops::Value, type_row}; + use itertools::Itertools; + + use crate::dead_code::PreserveNode; + + use super::DeadCodeElimPass; + + #[test] + fn test_cfg_callback() { + let mut cb = + CFGBuilder::new(Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID)) + .unwrap(); + let cst_unused = cb.add_constant(Value::from(ConstUsize::new(3))); + let cst_used_in_dfg = cb.add_constant(Value::from(ConstUsize::new(5))); + let cst_used = cb.add_constant(Value::unary_unit_sum()); + let mut block = cb.entry_builder([type_row![]], type_row![]).unwrap(); + let mut dfg_unused = block + .dfg_builder(Signature::new(type_row![], usize_t()), []) + .unwrap(); + let lc_unused = dfg_unused.load_const(&cst_unused); + let lc1 = dfg_unused.load_const(&cst_used_in_dfg); + let dfg_unused = dfg_unused.finish_with_outputs([lc1]).unwrap().node(); + let pred = block.load_const(&cst_used); + let block = block.finish_with_outputs(pred, []).unwrap(); + let exit = cb.exit_block(); + cb.branch(&block, 0, &exit).unwrap(); + let orig = cb.finish_hugr().unwrap(); + + // Callbacks that allow removing the DFG (and cst_unused) + for dce in [ + DeadCodeElimPass::default(), + // keep the node inside the DFG, but remove the DFG without checking its children: + 
DeadCodeElimPass::default().set_preserve_callback(Arc::new(move |h, n| { + if n == dfg_unused || h.get_optype(n).is_const() { + PreserveNode::CanRemove + } else { + PreserveNode::MustKeep + } + })), + ] { + let mut h = orig.clone(); + dce.run(&mut h).unwrap(); + assert_eq!( + h.children(h.root()).collect_vec(), + [block.node(), exit.node(), cst_used.node()] + ); + assert_eq!( + h.children(block.node()) + .map(|n| h.get_optype(n).tag()) + .collect_vec(), + [OpTag::Input, OpTag::Output, OpTag::LoadConst] + ); + } + + // Callbacks that prevent removing any node... + fn keep_if(b: bool) -> PreserveNode { + if b { + PreserveNode::MustKeep + } else { + PreserveNode::UseDefault + } + } + for dce in [ + DeadCodeElimPass::default() + .set_preserve_callback(Arc::new(|_, _| PreserveNode::MustKeep)), + // keeping the unused node in the DFG, means keeping the DFG (which uses its other children) + DeadCodeElimPass::default() + .set_preserve_callback(Arc::new(move |_, n| keep_if(n == lc_unused.node()))), + ] { + let mut h = orig.clone(); + dce.run(&mut h).unwrap(); + assert_eq!(orig, h); + } + + // Callbacks that keep the DFG but allow removing the unused constant + for dce in [ + DeadCodeElimPass::default() + .set_preserve_callback(Arc::new(move |_, n| keep_if(n == dfg_unused))), + DeadCodeElimPass::default() + .set_preserve_callback(Arc::new(move |_, n| keep_if(n == lc1.node()))), + ] { + let mut h = orig.clone(); + dce.run(&mut h).unwrap(); + assert_eq!( + h.children(h.root()).collect_vec(), + [ + block.node(), + exit.node(), + cst_used_in_dfg.node(), + cst_used.node() + ] + ); + assert_eq!( + h.children(block.node()).skip(2).collect_vec(), + [dfg_unused, pred.node()] + ); + assert_eq!( + h.children(dfg_unused.node()) + .map(|n| h.get_optype(n).tag()) + .collect_vec(), + [OpTag::Input, OpTag::Output, OpTag::LoadConst] + ); + } + + // Callback that allows removing the DFG but require keeping cst_unused + { + let cst_unused = cst_unused.node(); + let mut h = orig.clone(); + 
DeadCodeElimPass::default() + .set_preserve_callback(Arc::new(move |_, n| keep_if(n == cst_unused))) + .run(&mut h) + .unwrap(); + assert_eq!( + h.children(h.root()).collect_vec(), + [block.node(), exit.node(), cst_unused, cst_used.node()] + ); + assert_eq!( + h.children(block.node()) + .map(|n| h.get_optype(n).tag()) + .collect_vec(), + [OpTag::Input, OpTag::Output, OpTag::LoadConst] + ); + } + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 026e817c1..09d349ca8 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -3,6 +3,8 @@ pub mod call_graph; pub mod const_fold; pub mod dataflow; +pub mod dead_code; +pub use dead_code::DeadCodeElimPass; mod dead_funcs; pub use dead_funcs::{remove_dead_funcs, RemoveDeadFuncsError, RemoveDeadFuncsPass}; pub mod force_order; @@ -10,6 +12,7 @@ mod half_node; pub mod lower; pub mod merge_bbs; mod monomorphize; +use hugr_core::{HugrView, Node}; // TODO: Deprecated re-export. Remove on a breaking release. #[deprecated( since = "0.14.1", @@ -31,3 +34,15 @@ pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; + +fn find_main(h: &impl HugrView) -> Option { + let root = h.root(); + if !h.get_optype(root).is_module() { + return None; + } + h.children(root).find(|n| { + h.get_optype(*n) + .as_func_defn() + .is_some_and(|f| f.name == "main") + }) +}