Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor/signature-re…
Browse files Browse the repository at this point in the history
…sources2
  • Loading branch information
croyzor committed Aug 1, 2023
2 parents 3e5c250 + 1117634 commit aa96d93
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 188 deletions.
18 changes: 8 additions & 10 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,8 @@ pub(crate) mod test {
use super::*;
use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder};
use crate::hugr::region::{FlatRegionView, Region};
use crate::ops::{
handle::{BasicBlockID, ConstID, NodeHandle},
ConstValue,
};
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::Const;
use crate::types::{ClassicType, SimpleType};
use crate::{type_row, Hugr};
const NAT: SimpleType = SimpleType::Classic(ClassicType::i64());
Expand All @@ -428,8 +426,8 @@ pub(crate) mod test {
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;

let pred_const = cfg_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(ConstValue::simple_unary_predicate())?;
let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1)?,
Expand Down Expand Up @@ -647,8 +645,8 @@ pub(crate) mod test {
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let pred_const = cfg_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(ConstValue::simple_unary_predicate())?;
let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2)?,
Expand Down Expand Up @@ -682,8 +680,8 @@ pub(crate) mod test {

let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;

let pred_const = cfg_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(ConstValue::simple_unary_predicate())?;
let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1)?,
Expand Down
2 changes: 1 addition & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

use crate::hugr::typecheck::ConstTypeError;
use crate::hugr::{HugrError, Node, ValidationError, Wire};
use crate::ops::constant::typecheck::ConstTypeError;
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::types::SimpleType;

Expand Down
27 changes: 15 additions & 12 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::view::HugrView;
use crate::hugr::{Node, NodeMetadata, Port, ValidationError};
use crate::ops::{self, ConstValue, LeafOp, OpTrait, OpType};
use crate::ops::{self, LeafOp, OpTrait, OpType};

use std::iter;

Expand Down Expand Up @@ -70,11 +70,10 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(&mut self, val: ConstValue) -> Result<ConstID, BuildError> {
let typ = val.const_type();
let const_n = self.add_child_op(ops::Const::new(val).map_err(BuildError::BadConstant)?)?;
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_op(constant)?;

Ok((const_n, typ).into())
Ok(const_n.into())
}

/// Add a [`ops::FuncDefn`] node and returns a builder to define the function
Expand Down Expand Up @@ -334,13 +333,17 @@ pub trait Dataflow: Container {
/// This function will return an error if there is an error when adding the node.
fn load_const(&mut self, cid: &ConstID) -> Result<Wire, BuildError> {
let const_node = cid.node();
let op: ops::Const = self
.hugr()
.get_optype(const_node)
.clone()
.try_into()
.expect("ConstID does not refer to Const op.");

let op: OpType = ops::LoadConstant {
datatype: cid.const_type(),
}
.into();
let load_n = self.add_dataflow_op(
op,
ops::LoadConstant {
datatype: op.const_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, Port::new_outgoing(0))],
)?;
Expand All @@ -353,8 +356,8 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(&mut self, val: ConstValue) -> Result<Wire, BuildError> {
let cid = self.add_constant(val)?;
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}

Expand Down
4 changes: 2 additions & 2 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ mod test {
use crate::builder::{DataflowSubContainer, ModuleBuilder};
use crate::macros::classic_row;
use crate::types::ClassicType;
use crate::{builder::test::NAT, ops::ConstValue, type_row};
use crate::{builder::test::NAT, type_row};
use cool_asserts::assert_matches;

use super::*;
Expand Down Expand Up @@ -361,7 +361,7 @@ mod test {
};
let mut middle_b = cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?;
let middle = {
let c = middle_b.add_load_const(ConstValue::simple_unary_predicate())?;
let c = middle_b.add_load_const(ops::Const::simple_unary_predicate())?;
let [inw] = middle_b.input_wires_arr();
middle_b.finish_with_outputs(c, [inw])?
};
Expand Down
6 changes: 3 additions & 3 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,13 @@ mod test {
use cool_asserts::assert_matches;

use crate::builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder};

use crate::{
builder::{
test::{n_identity, NAT},
Dataflow,
},
ops::ConstValue,
ops::Const,
type_row,
};

Expand All @@ -225,7 +226,6 @@ mod test {

n_identity(conditional_b.case_builder(0)?)?;
n_identity(conditional_b.case_builder(1)?)?;

Ok(())
}

Expand All @@ -237,7 +237,7 @@ mod test {
"main",
AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(),
)?;
let tru_const = fbuild.add_constant(ConstValue::true_val())?;
let tru_const = fbuild.add_constant(Const::true_val())?;
let _fdef = {
let const_wire = fbuild.load_const(&tru_const)?;
let [int] = fbuild.input_wires_arr();
Expand Down
8 changes: 4 additions & 4 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ mod test {
},
classic_row,
hugr::ValidationError,
ops::ConstValue,
ops::Const,
type_row,
types::ClassicType,
Hugr,
Expand All @@ -112,7 +112,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![ClassicType::i64()])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(ConstValue::i64(1))?;
let const_wire = loop_b.add_load_const(Const::i64(1)?)?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -140,7 +140,7 @@ mod test {
classic_row![ClassicType::i64()],
)?;
let signature = loop_b.loop_signature()?.clone();
let const_wire = loop_b.add_load_const(ConstValue::true_val())?;
let const_wire = loop_b.add_load_const(Const::true_val())?;
let [b1] = loop_b.input_wires_arr();
let conditional_id = {
let predicate_inputs = vec![type_row![]; 2];
Expand All @@ -160,7 +160,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(ConstValue::i64(2))?;
let wire = branch_1.add_load_const(Const::i64(2)?)?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
1 change: 0 additions & 1 deletion src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ mod hugrmut;
pub mod region;
pub mod rewrite;
pub mod serialize;
pub mod typecheck;
pub mod validate;
pub mod view;

Expand Down
5 changes: 3 additions & 2 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use thiserror::Error;
use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
use crate::hugr::rewrite::Rewrite;
use crate::hugr::{HugrMut, HugrView};
use crate::ops::{BasicBlock, ConstValue, OpTag, OpTrait, OpType};
use crate::ops;
use crate::ops::{BasicBlock, OpTag, OpTrait, OpType};
use crate::{type_row, Hugr, Node};

/// Moves part of a Control-flow Sibling Graph into a new CFG-node
Expand Down Expand Up @@ -108,7 +109,7 @@ impl Rewrite for OutlineCfg {
cfg.exit_block(); // Makes inner exit block (but no entry block)
let cfg_outputs = cfg.finish_sub_container().unwrap().outputs();
let predicate = new_block_bldr
.add_constant(ConstValue::simple_predicate(0, 1))
.add_constant(ops::Const::simple_unary_predicate())
.unwrap();
let pred_wire = new_block_bldr.load_const(&predicate).unwrap();
let new_block_hugr = new_block_bldr
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ mod test {
parent: Node,
predicate_size: usize,
) -> (Node, Node, Node, Node) {
let const_op = ops::Const::new(ConstValue::simple_predicate(0, predicate_size)).unwrap();
let const_op = ops::Const::simple_predicate(0, predicate_size);
let tag_type = SimpleType::Classic(ClassicType::new_simple_predicate(predicate_size));

let input = b
Expand Down Expand Up @@ -1148,7 +1148,7 @@ mod test {
// Second input of Xor from a constant
let cst = h.add_op_with_parent(
h.root(),
ops::Const::new(ConstValue::Int { width: 1, value: 1 }).unwrap(),
ops::Const::new(ConstValue::Int(1), ClassicType::int::<1>()).unwrap(),
)?;
let lcst = h.add_op_with_parent(
h.root(),
Expand Down
Loading

0 comments on commit aa96d93

Please sign in to comment.