Skip to content

Commit

Permalink
TailLoopTermination just examine whatever PartialValue's we have, rem…
Browse files Browse the repository at this point in the history
…ove most
  • Loading branch information
acl-cqc committed Sep 2, 2024
1 parent 41820f2 commit e54c742
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 129 deletions.
23 changes: 8 additions & 15 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use ascent::lattice::BoundedLattice;
use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
use std::collections::HashMap;
use std::hash::Hash;
Expand Down Expand Up @@ -108,15 +107,6 @@ ascent::ascent! {
if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1
for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields);

lattice tail_loop_termination(C,Node,TailLoopTermination);
tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <--
tail_loop_node(c,tl_n);
tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <--
tail_loop_node(c,tl_n),
io_node(c,tl,out_n, IO::Output),
in_wire_value(c, out_n, IncomingPort::from(0), v);


// Conditional
relation conditional_node(C, Node);
relation case_node(C,Node,usize, Node);
Expand Down Expand Up @@ -221,11 +211,14 @@ impl<V: AbstractValue, C: DFContext<V>> Machine<V, C> {

pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination {
assert!(hugr.get_optype(node).is_tail_loop());
self.0
.tail_loop_termination
.iter()
.find_map(|(_, n, v)| (n == &node).then_some(*v))
.unwrap()
let [_, out] = hugr.get_io(node).unwrap();
TailLoopTermination::from_control_value(
self.0
.in_wire_value
.iter()
.find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v))
.unwrap(),
)
}

pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool {
Expand Down
2 changes: 1 addition & 1 deletion hugr-passes/src/dataflow/datalog/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn test_tail_loop_always_iterates() {
let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap();
assert_eq!(o_r2, PartialValue::bottom().into());
assert_eq!(
TailLoopTermination::bottom(),
TailLoopTermination::Bottom,
machine.tail_loop_terminates(&hugr, tail_loop.node())
)
}
Expand Down
115 changes: 2 additions & 113 deletions hugr-passes/src/dataflow/datalog/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@
// https://github.com/proptest-rs/proptest/issues/447
#![cfg_attr(test, allow(non_local_definitions))]

use std::cmp::Ordering;

use ascent::lattice::{BoundedLattice, Lattice};

use super::super::partial_value::{AbstractValue, PartialValue};
use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort};

#[cfg(test)]
use proptest_derive::Arbitrary;

impl<V: AbstractValue> Lattice for PartialValue<V> {
fn meet(self, other: Self) -> Self {
self.meet(other)
Expand Down Expand Up @@ -57,7 +52,6 @@ pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator<Item =
}

#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum TailLoopTermination {
Bottom,
ExactlyZeroContinues,
Expand All @@ -70,114 +64,9 @@ impl TailLoopTermination {
if may_break && !may_continue {
Self::ExactlyZeroContinues
} else if may_break && may_continue {
Self::top()
Self::Top
} else {
Self::bottom()
}
}
}

impl PartialOrd for TailLoopTermination {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
if self == other {
return Some(std::cmp::Ordering::Equal);
};
match (self, other) {
(Self::Bottom, _) => Some(Ordering::Less),
(_, Self::Bottom) => Some(Ordering::Greater),
(Self::Top, _) => Some(Ordering::Greater),
(_, Self::Top) => Some(Ordering::Less),
_ => None,
}
}
}

impl Lattice for TailLoopTermination {
fn meet(mut self, other: Self) -> Self {
self.meet_mut(other);
self
}

fn join(mut self, other: Self) -> Self {
self.join_mut(other);
self
}

fn meet_mut(&mut self, other: Self) -> bool {
// let new_self = &mut self;
match (*self).partial_cmp(&other) {
Some(Ordering::Greater) => {
*self = other;
true
}
Some(_) => false,
_ => {
*self = Self::Bottom;
true
}
}
}

fn join_mut(&mut self, other: Self) -> bool {
match (*self).partial_cmp(&other) {
Some(Ordering::Less) => {
*self = other;
true
}
Some(_) => false,
_ => {
*self = Self::Top;
true
}
}
}
}

impl BoundedLattice for TailLoopTermination {
fn bottom() -> Self {
Self::Bottom
}

fn top() -> Self {
Self::Top
}
}

#[cfg(test)]
#[cfg_attr(test, allow(non_local_definitions))]
mod test {
use super::*;
use proptest::prelude::*;

proptest! {
#[test]
fn bounded_lattice(v: TailLoopTermination) {
prop_assert!(v <= TailLoopTermination::top());
prop_assert!(v >= TailLoopTermination::bottom());
}

#[test]
fn meet_join_self_noop(v1: TailLoopTermination) {
let mut subject = v1.clone();

assert_eq!(v1.clone(), v1.clone().join(v1.clone()));
assert!(!subject.join_mut(v1.clone()));
assert_eq!(subject, v1);

assert_eq!(v1.clone(), v1.clone().meet(v1.clone()));
assert!(!subject.meet_mut(v1.clone()));
assert_eq!(subject, v1);
}

#[test]
fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) {
let meet = v1.clone().meet(v2.clone());
prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet);
prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet);

let join = v1.clone().join(v2.clone());
prop_assert!(join >= v1, "join not >=: {:#?}", &join);
prop_assert!(join >= v2, "join not >=: {:#?}", &join);
Self::Bottom
}
}
}

0 comments on commit e54c742

Please sign in to comment.