Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add separate DCE pass #1902

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 35 additions & 106 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! An (example) use of the [dataflow analysis framework](super::dataflow).

pub mod value_handle;
use std::collections::{HashMap, HashSet, VecDeque};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;

use hugr_core::{
Expand All @@ -13,18 +13,20 @@ use hugr_core::{
},
ops::{
constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant,
OpType, Value,
Value,
},
types::{EdgeKind, TypeArg},
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
};
use value_handle::ValueHandle;

use crate::dataflow::{
partial_from_const, AbstractValue, AnalysisResults, ConstLoader, ConstLocation, DFContext,
Machine, PartialValue, TailLoopTermination,
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
TailLoopTermination,
};
use crate::dead_code::PreserveNode;
use crate::validation::{ValidatePassError, ValidationLevel};
use crate::{find_main, DeadCodeElimPass};

#[derive(Debug, Clone, Default)]
/// A configuration for the Constant Folding pass.
Expand Down Expand Up @@ -88,22 +90,15 @@ impl ConstantFoldPass {
});

let results = Machine::new(&hugr).run(ConstFoldContext(hugr), inputs);
let keep_nodes = self.find_needed_nodes(&results);
let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i);

let remove_nodes = hugr
let wires_to_break = hugr
.nodes()
.filter(|n| !keep_nodes.contains(n))
.collect::<HashSet<_>>();
let wires_to_break = keep_nodes
.into_iter()
.flat_map(|n| hugr.node_inputs(n).map(move |ip| (n, ip)))
.filter(|(n, ip)| {
*n != hugr.root()
&& matches!(hugr.get_optype(*n).port_kind(*ip), Some(EdgeKind::Value(_)))
})
// Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping.
// This would insert fewer constants, but potentially expose less parallelism.
.filter_map(|(n, ip)| {
let (src, outp) = hugr.single_linked_output(n, ip).unwrap();
// Avoid breaking edges from existing LoadConstant (we'd only add another)
Expand All @@ -118,20 +113,42 @@ impl ConstantFoldPass {
))
})
.collect::<Vec<_>>();
// Sadly the results immutably borrow the hugr, so we must extract everything we need before mutation
let terminating_tail_loops = hugr
.nodes()
.filter(|n| {
results.tail_loop_terminates(*n) == Some(TailLoopTermination::NeverContinues)
})
.collect::<Vec<_>>();

for (n, import, v) in wires_to_break {
for (n, inport, v) in wires_to_break {
let parent = hugr.get_parent(n).unwrap();
let datatype = v.get_type();
// We could try hash-consing identical Consts, but not ATM
let cst = hugr.add_node_with_parent(parent, Const::new(v));
let lcst = hugr.add_node_with_parent(parent, LoadConstant { datatype });
hugr.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0));
hugr.disconnect(n, import);
hugr.connect(lcst, OutgoingPort::from(0), n, import);
}
for n in remove_nodes {
hugr.remove_node(n);
hugr.disconnect(n, inport);
hugr.connect(lcst, OutgoingPort::from(0), n, inport);
}
// Dataflow analysis applies our inputs to the 'main' function if this is a Module, so do the same here
DeadCodeElimPass::default()
.with_entry_points(hugr.get_optype(hugr.root()).is_module().then(
// No main => remove everything, so not much use
|| find_main(hugr).unwrap(),
))
.set_preserve_callback(if self.allow_increase_termination {
Arc::new(|_, _| PreserveNode::CanRemove)
} else {
Arc::new(move |_, n| {
if terminating_tail_loops.contains(&n) {
PreserveNode::RemoveIfAllChildrenCanBeRemoved
} else {
PreserveNode::UseDefault
}
})
})
.run(hugr)?;
Ok(())
}

Expand All @@ -140,94 +157,6 @@ impl ConstantFoldPass {
self.validation
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
}

fn find_needed_nodes<H: HugrView>(
&self,
results: &AnalysisResults<ValueHandle, H>,
) -> HashSet<Node> {
let mut needed = HashSet::new();
let h = results.hugr();
let mut q = VecDeque::from_iter([h.root()]);
while let Some(n) = q.pop_front() {
if !needed.insert(n) {
continue;
};
if h.get_optype(n).is_module() {
for ch in h.children(n) {
match h.get_optype(ch) {
OpType::AliasDecl(_) | OpType::AliasDefn(_) => {
// Use of these is done via names, rather than following edges.
// We could track these as well but for now be conservative.
q.push_back(ch);
}
OpType::FuncDefn(f) if f.name == "main" => {
// Dataflow analysis will have applied any inputs the 'main' function, so assume reachable.
q.push_back(ch);
}
_ => (),
}
}
} else if h.get_optype(n).is_cfg() {
for bb in h.children(n) {
//if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates
q.push_back(bb);
}
} else if let Some(inout) = h.get_io(n) {
// Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges.
q.extend(inout); // Input also necessary for legality even if unreachable

if !self.allow_increase_termination {
// Also add on anything that might not terminate (even if results not required -
// if its results are required we'll add it by following dataflow, below.)
for ch in h.children(n) {
if might_diverge(results, ch) {
q.push_back(ch);
}
}
}
}
// Also follow dataflow demand
for (src, op) in h.all_linked_outputs(n) {
let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() {
EdgeKind::Value(_) => {
h.get_optype(src).is_load_constant()
|| results
.try_read_wire_concrete::<Value, _, _>(Wire::new(src, op))
.is_err()
}
EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true,
EdgeKind::ControlFlow => false, // we always include all children of a CFG above
_ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst
};
if needs_predecessor {
q.push_back(src);
}
}
}
needed
}
}

// "Diverge" aka "never-terminate"
// TODO would be more efficient to compute this bottom-up and cache (dynamic programming)
fn might_diverge<V: AbstractValue>(results: &AnalysisResults<V, impl HugrView>, n: Node) -> bool {
let op = results.hugr().get_optype(n);
if op.is_cfg() {
// TODO if the CFG has no cycles (that are possible given predicates)
// then we could say it definitely terminates (i.e. return false)
true
} else if op.is_tail_loop()
&& results.tail_loop_terminates(n).unwrap() != TailLoopTermination::NeverContinues
{
// If we can even figure out the number of iterations is bounded that would allow returning false.
true
} else {
// Node does not introduce non-termination, but still non-terminates if any of its children does
results
.hugr()
.children(n)
.any(|ch| might_diverge(results, ch))
}
}

/// Exhaustively apply constant folding to a HUGR.
Expand Down
9 changes: 3 additions & 6 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
use hugr_core::ops::{OpTrait, OpType, TailLoop};
use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire};

use crate::find_main;

use super::value_row::ValueRow;
use super::{
partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext,
Expand Down Expand Up @@ -83,12 +85,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
// we must find the corresponding Input node.
let input_node_parent = match self.0.get_optype(root) {
OpType::Module(_) => {
let main = self.0.children(root).find(|n| {
self.0
.get_optype(*n)
.as_func_defn()
.is_some_and(|f| f.name == "main")
});
let main = find_main(&self.0);
if main.is_none() && in_values.next().is_some() {
panic!("Cannot give inputs to module with no 'main'");
}
Expand Down
Loading
Loading