diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index 80b4ee8c6..b124319f1 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -399,10 +399,8 @@ pub(crate) mod test { use super::*; use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; use crate::hugr::region::{FlatRegionView, Region}; - use crate::ops::{ - handle::{BasicBlockID, ConstID, NodeHandle}, - ConstValue, - }; + use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; + use crate::ops::Const; use crate::types::{ClassicType, SimpleType}; use crate::{type_row, Hugr}; const NAT: SimpleType = SimpleType::Classic(ClassicType::i64()); @@ -428,8 +426,8 @@ pub(crate) mod test { // \-> right -/ \-<--<-/ let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; - let pred_const = cfg_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(ConstValue::simple_unary_predicate())?; + let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1)?, @@ -647,8 +645,8 @@ pub(crate) mod test { separate: bool, ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; - let pred_const = cfg_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(ConstValue::simple_unary_predicate())?; + let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 2)?, @@ -682,8 +680,8 @@ pub(crate) mod test { let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; - let pred_const = cfg_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(ConstValue::simple_unary_predicate())?; + let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1)?, diff --git a/src/builder.rs b/src/builder.rs index e8865bafe..17b6f81b0 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -5,8 +5,8 @@ use thiserror::Error; #[cfg(feature = "pyo3")] use pyo3::prelude::*; -use crate::hugr::typecheck::ConstTypeError; use crate::hugr::{HugrError, Node, ValidationError, Wire}; +use crate::ops::constant::typecheck::ConstTypeError; use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID}; use crate::types::SimpleType; diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 7fe78e16d..2c58793ae 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -1,7 +1,7 @@ use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::view::HugrView; use crate::hugr::{Node, NodeMetadata, Port, ValidationError}; -use crate::ops::{self, ConstValue, LeafOp, OpTrait, OpType}; +use crate::ops::{self, LeafOp, OpTrait, OpType}; use std::iter; @@ -70,11 +70,10 @@ pub trait Container { /// /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. - fn add_constant(&mut self, val: ConstValue) -> Result { - let typ = val.const_type(); - let const_n = self.add_child_op(ops::Const::new(val).map_err(BuildError::BadConstant)?)?; + fn add_constant(&mut self, constant: ops::Const) -> Result { + let const_n = self.add_child_op(constant)?; - Ok((const_n, typ).into()) + Ok(const_n.into()) } /// Add a [`ops::FuncDefn`] node and returns a builder to define the function @@ -334,13 +333,17 @@ pub trait Dataflow: Container { /// This function will return an error if there is an error when adding the node. fn load_const(&mut self, cid: &ConstID) -> Result { let const_node = cid.node(); + let op: ops::Const = self + .hugr() + .get_optype(const_node) + .clone() + .try_into() + .expect("ConstID does not refer to Const op."); - let op: OpType = ops::LoadConstant { - datatype: cid.const_type(), - } - .into(); let load_n = self.add_dataflow_op( - op, + ops::LoadConstant { + datatype: op.const_type().clone(), + }, // Constant wire from the constant value node vec![Wire::new(const_node, Port::new_outgoing(0))], )?; @@ -353,8 +356,8 @@ pub trait Dataflow: Container { /// # Errors /// /// This function will return an error if there is an error when adding the node. - fn add_load_const(&mut self, val: ConstValue) -> Result { - let cid = self.add_constant(val)?; + fn add_load_const(&mut self, constant: ops::Const) -> Result { + let cid = self.add_constant(constant)?; self.load_const(&cid) } diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 59fd54118..574f32eea 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -304,7 +304,7 @@ mod test { use crate::builder::{DataflowSubContainer, ModuleBuilder}; use crate::macros::classic_row; use crate::types::ClassicType; - use crate::{builder::test::NAT, ops::ConstValue, type_row}; + use crate::{builder::test::NAT, type_row}; use cool_asserts::assert_matches; use super::*; @@ -361,7 +361,7 @@ mod test { }; let mut middle_b = cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?; let middle = { - let c = middle_b.add_load_const(ConstValue::simple_unary_predicate())?; + let c = middle_b.add_load_const(ops::Const::simple_unary_predicate())?; let [inw] = middle_b.input_wires_arr(); middle_b.finish_with_outputs(c, [inw])? }; diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index e35c746e7..3fd6ae4b2 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -206,12 +206,13 @@ mod test { use cool_asserts::assert_matches; use crate::builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder}; + use crate::{ builder::{ test::{n_identity, NAT}, Dataflow, }, - ops::ConstValue, + ops::Const, type_row, }; @@ -225,7 +226,6 @@ mod test { n_identity(conditional_b.case_builder(0)?)?; n_identity(conditional_b.case_builder(1)?)?; - Ok(()) } @@ -237,7 +237,7 @@ mod test { "main", AbstractSignature::new_df(type_row![NAT], type_row![NAT]).pure(), )?; - let tru_const = fbuild.add_constant(ConstValue::true_val())?; + let tru_const = fbuild.add_constant(Const::true_val())?; let _fdef = { let const_wire = fbuild.load_const(&tru_const)?; let [int] = fbuild.input_wires_arr(); diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 2a7fc0b65..d0de055a9 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -100,7 +100,7 @@ mod test { }, classic_row, hugr::ValidationError, - ops::ConstValue, + ops::Const, type_row, types::ClassicType, Hugr, @@ -112,7 +112,7 @@ mod test { let build_result: Result = { let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![ClassicType::i64()])?; let [i1] = loop_b.input_wires_arr(); - let const_wire = loop_b.add_load_const(ConstValue::i64(1))?; + let const_wire = loop_b.add_load_const(Const::i64(1)?)?; let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; @@ -140,7 +140,7 @@ mod test { classic_row![ClassicType::i64()], )?; let signature = loop_b.loop_signature()?.clone(); - let const_wire = loop_b.add_load_const(ConstValue::true_val())?; + let const_wire = loop_b.add_load_const(Const::true_val())?; let [b1] = loop_b.input_wires_arr(); let conditional_id = { let predicate_inputs = vec![type_row![]; 2]; @@ -160,7 +160,7 @@ mod test { let mut branch_1 = conditional_b.case_builder(1)?; let [_b1] = branch_1.input_wires_arr(); - let wire = branch_1.add_load_const(ConstValue::i64(2))?; + let wire = branch_1.add_load_const(Const::i64(2)?)?; let break_wire = branch_1.make_break(signature, [wire])?; branch_1.finish_with_outputs([break_wire])?; diff --git a/src/hugr.rs b/src/hugr.rs index de9c6d230..d3be07ad9 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -5,7 +5,6 @@ mod hugrmut; pub mod region; pub mod rewrite; pub mod serialize; -pub mod typecheck; pub mod validate; pub mod view; diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 771d75882..4ba18b03a 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -7,7 +7,8 @@ use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; use crate::hugr::rewrite::Rewrite; use crate::hugr::{HugrMut, HugrView}; -use crate::ops::{BasicBlock, ConstValue, OpTag, OpTrait, OpType}; +use crate::ops; +use crate::ops::{BasicBlock, OpTag, OpTrait, OpType}; use crate::{type_row, Hugr, Node}; /// Moves part of a Control-flow Sibling Graph into a new CFG-node @@ -108,7 +109,7 @@ impl Rewrite for OutlineCfg { cfg.exit_block(); // Makes inner exit block (but no entry block) let cfg_outputs = cfg.finish_sub_container().unwrap().outputs(); let predicate = new_block_bldr - .add_constant(ConstValue::simple_predicate(0, 1)) + .add_constant(ops::Const::simple_unary_predicate()) .unwrap(); let pred_wire = new_block_bldr.load_const(&predicate).unwrap(); let new_block_hugr = new_block_bldr diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index c5d00d9f3..11d50145c 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -813,7 +813,7 @@ mod test { parent: Node, predicate_size: usize, ) -> (Node, Node, Node, Node) { - let const_op = ops::Const::new(ConstValue::simple_predicate(0, predicate_size)).unwrap(); + let const_op = ops::Const::simple_predicate(0, predicate_size); let tag_type = SimpleType::Classic(ClassicType::new_simple_predicate(predicate_size)); let input = b @@ -1148,7 +1148,7 @@ mod test { // Second input of Xor from a constant let cst = h.add_op_with_parent( h.root(), - ops::Const::new(ConstValue::Int { width: 1, value: 1 }).unwrap(), + ops::Const::new(ConstValue::Int(1), ClassicType::int::<1>()).unwrap(), )?; let lcst = h.add_op_with_parent( h.root(), diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 437b9ac5a..247659bd5 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -3,33 +3,95 @@ use std::any::Any; use crate::{ - classic_row, - hugr::typecheck::{typecheck_const, ConstTypeError}, macros::impl_box_clone, - types::{ClassicRow, ClassicType, Container, CustomType, EdgeKind, HashableType}, + types::{ClassicRow, ClassicType, CustomType, EdgeKind}, }; use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; +use self::typecheck::{typecheck_const, ConstTypeError}; + use super::OpTag; use super::{OpName, OpTrait, StaticTag}; +pub mod typecheck; /// A constant value definition. -#[derive(Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize)] -pub struct Const(ConstValue); +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct Const { + value: ConstValue, + typ: ClassicType, +} impl Const { /// Creates a new Const, type-checking the value. - pub fn new(val: ConstValue) -> Result { - typecheck_const(&val.const_type(), &val)?; - Ok(Const(val)) + pub fn new(value: ConstValue, typ: ClassicType) -> Result { + value.check_type(&typ)?; + Ok(Self { value, typ }) + } + + /// Returns a reference to the value of this [`Const`]. + pub fn value(&self) -> &ConstValue { + &self.value + } + + /// Returns a reference to the type of this [`Const`]. + pub fn const_type(&self) -> &ClassicType { + &self.typ + } + + /// Sum of Tuples, used as predicates in branching. + /// Tuple rows are defined in order by input rows. + pub fn predicate( + tag: usize, + value: ConstValue, + variant_rows: impl IntoIterator, + ) -> Result { + let typ = ClassicType::new_predicate(variant_rows); + + Self::new(ConstValue::Sum(tag, Box::new(value)), typ) + } + + /// Constant Sum over units, used as predicates. + pub fn simple_predicate(tag: usize, size: usize) -> Self { + Self { + value: ConstValue::simple_predicate(tag), + typ: ClassicType::new_simple_predicate(size), + } + } + + /// Constant Sum over units, with only one variant. + pub fn simple_unary_predicate() -> Self { + Self { + value: ConstValue::simple_unary_predicate(), + typ: ClassicType::new_simple_predicate(1), + } + } + + /// Constant "true" value, i.e. the second variant of Sum((), ()). + pub fn true_val() -> Self { + Self::simple_predicate(1, 2) + } + + /// Constant "false" value, i.e. the first variant of Sum((), ()). + pub fn false_val() -> Self { + Self::simple_predicate(0, 2) + } + + /// Fixed width integer + pub fn int(value: HugrIntValueStore) -> Result { + Self::new(ConstValue::Int(value), ClassicType::int::()) + } + + /// 64-bit integer + pub fn i64(value: i64) -> Result { + Self::int::<64>(value as HugrIntValueStore) } } impl OpName for Const { fn name(&self) -> SmolStr { - self.0.name() + self.value.name() } } impl StaticTag for Const { @@ -37,7 +99,7 @@ impl StaticTag for Const { } impl OpTrait for Const { fn description(&self) -> &str { - self.0.description() + self.value.description() } fn tag(&self) -> OpTag { @@ -45,7 +107,7 @@ impl OpTrait for Const { } fn other_output(&self) -> Option { - Some(EdgeKind::Static(self.0.const_type())) + Some(EdgeKind::Static(self.typ.clone())) } } @@ -63,20 +125,11 @@ pub(crate) const HUGR_MAX_INT_WIDTH: HugrIntWidthStore = #[allow(missing_docs)] pub enum ConstValue { /// An arbitrary length integer constant. - Int { - value: HugrIntValueStore, - width: HugrIntWidthStore, - }, + Int(HugrIntValueStore), /// Double precision float F64(f64), /// A constant specifying a variant of a Sum type. - Sum { - tag: usize, - // We require the type to be entirely Classic (i.e. we don't allow - // a classic variant of a Sum with other variants that are linear) - variants: ClassicRow, - val: Box, - }, + Sum(usize, Box), /// A tuple of constant values. Tuple(Vec), /// An opaque constant value, with cached type @@ -92,34 +145,22 @@ impl PartialEq for dyn CustomConst { impl Default for ConstValue { fn default() -> Self { - Self::Int { - value: 0, - width: 64, - } + Self::Int(0) } } impl ConstValue { /// Returns the datatype of the constant. - pub fn const_type(&self) -> ClassicType { - match self { - Self::Int { value: _, width } => HashableType::Int(*width).into(), - Self::Opaque((_, b)) => Container::Opaque((*b).custom_type()).into(), - Self::Sum { variants, .. } => ClassicType::new_sum(variants.clone()), - Self::Tuple(vals) => { - let row: Vec<_> = vals.iter().map(|val| val.const_type()).collect(); - ClassicType::new_tuple(row) - } - Self::F64(_) => ClassicType::F64, - } + pub fn check_type(&self, typ: &ClassicType) -> Result<(), ConstTypeError> { + typecheck_const(typ, self) } /// Unique name of the constant. pub fn name(&self) -> SmolStr { match self { - Self::Int { value, width } => format!("const:int<{width}>:{value}"), + Self::Int(value) => format!("const:int{value}"), Self::F64(f) => format!("const:float:{f}"), Self::Opaque((_, v)) => format!("const:{}", v.name()), - Self::Sum { tag, val, .. } => { + Self::Sum(tag, val) => { format!("const:sum:{{tag:{tag}, val:{}}}", val.name()) } Self::Tuple(vals) => { @@ -141,49 +182,19 @@ impl ConstValue { ConstValue::Tuple(vec![]) } - /// Constant "true" value, i.e. the second variant of Sum((), ()). - pub fn true_val() -> Self { - Self::simple_predicate(1, 2) - } - - /// Constant "false" value, i.e. the first variant of Sum((), ()). - pub fn false_val() -> Self { - Self::simple_predicate(0, 2) - } - /// Constant Sum over units, used as predicates. - pub fn simple_predicate(tag: usize, size: usize) -> Self { - Self::predicate( - tag, - Self::unit(), - std::iter::repeat(classic_row![]).take(size), - ) + pub fn simple_predicate(tag: usize) -> Self { + Self::predicate(tag, Self::unit()) } /// Constant Sum over Tuples, used as predicates. - pub fn predicate( - tag: usize, - val: ConstValue, - variant_rows: impl IntoIterator, - ) -> Self { - ConstValue::Sum { - tag, - variants: ClassicRow::predicate_variants_row(variant_rows), - val: Box::new(val), - } + pub fn predicate(tag: usize, val: ConstValue) -> Self { + ConstValue::Sum(tag, Box::new(val)) } /// Constant Sum over Tuples with just one variant of unit type pub fn simple_unary_predicate() -> Self { - Self::simple_predicate(0, 1) - } - - /// New 64 bit integer constant - pub fn i64(value: i64) -> Self { - Self::Int { - value: value as HugrIntValueStore, - width: 64, - } + Self::simple_predicate(0) } } @@ -220,12 +231,13 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); #[cfg(test)] mod test { + use cool_asserts::assert_matches; + use super::ConstValue; + use super::{typecheck::ConstTypeError, Const}; use crate::{ builder::{BuildError, Container, DFGBuilder, Dataflow, DataflowHugr}, - classic_row, - hugr::typecheck::ConstTypeError, - type_row, + classic_row, type_row, types::{ClassicType, SimpleRow, SimpleType}, }; @@ -238,20 +250,16 @@ mod test { let pred_ty = SimpleType::new_predicate(pred_rows.clone()); let mut b = DFGBuilder::new(type_row![], SimpleRow::from(vec![pred_ty.clone()]))?; - let c = b.add_constant(ConstValue::predicate( + let c = b.add_constant(Const::predicate( 0, - ConstValue::Tuple(vec![ConstValue::i64(3), ConstValue::F64(3.15)]), + ConstValue::Tuple(vec![ConstValue::Int(3), ConstValue::F64(3.15)]), pred_rows.clone(), - ))?; + )?)?; let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w]).unwrap(); let mut b = DFGBuilder::new(type_row![], SimpleRow::from(vec![pred_ty]))?; - let c = b.add_constant(ConstValue::predicate( - 1, - ConstValue::Tuple(vec![]), - pred_rows, - ))?; + let c = b.add_constant(Const::predicate(1, ConstValue::unit(), pred_rows)?)?; let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w]).unwrap(); @@ -264,17 +272,8 @@ mod test { classic_row![ClassicType::i64(), ClassicType::F64], type_row![], ]; - let pred_ty = SimpleType::new_predicate(pred_rows.clone()); - let mut b = DFGBuilder::new(type_row![], SimpleRow::from(vec![pred_ty])).unwrap(); - let res = b.add_constant(ConstValue::predicate( - 0, - ConstValue::Tuple(vec![]), - pred_rows, - )); - assert_eq!( - res, - Err(BuildError::BadConstant(ConstTypeError::TupleWrongLength)) - ); + let res = Const::predicate(0, ConstValue::Tuple(vec![]), pred_rows); + assert_matches!(res, Err(ConstTypeError::TupleWrongLength)); } } diff --git a/src/hugr/typecheck.rs b/src/ops/constant/typecheck.rs similarity index 75% rename from src/hugr/typecheck.rs rename to src/ops/constant/typecheck.rs index b7e2fa01b..6c2e2c2c1 100644 --- a/src/hugr/typecheck.rs +++ b/src/ops/constant/typecheck.rs @@ -5,11 +5,11 @@ use lazy_static::lazy_static; use std::collections::HashSet; -use crate::hugr::*; +use thiserror::Error; // For static typechecking use crate::ops::ConstValue; -use crate::types::{ClassicRow, ClassicType, Container, HashableType, PrimType, TypeRow}; +use crate::types::{ClassicType, Container, HashableType, PrimType, TypeRow}; use crate::ops::constant::{HugrIntValueStore, HugrIntWidthStore, HUGR_MAX_INT_WIDTH}; @@ -29,7 +29,7 @@ pub enum ConstIntError { } /// Errors that arise from typechecking constants -#[derive(Clone, Debug, Eq, PartialEq, Error)] +#[derive(Clone, Debug, PartialEq, Error)] pub enum ConstTypeError { /// This case hasn't been implemented. Possibly because we don't have value /// constructors to check against it @@ -51,12 +51,11 @@ pub enum ConstTypeError { #[error("Tag of Sum value is invalid")] InvalidSumTag, /// A mismatch between the type expected and the actual type of the constant - #[error("Type mismatch for const - expected {0}, found {1}")] + #[error("Type mismatch for const - expected {0}, found {1:?}")] TypeMismatch(ClassicType, ClassicType), - /// A mismatch between the embedded type and the type we're checking - /// against, as above, but for rows instead of simple types - #[error("Type mismatch for const - expected {0}, found {1}")] - TypeRowMismatch(ClassicRow, ClassicRow), + /// A mismatch between the type expected and the value. + #[error("Value {1:?} does not match expected type {0}")] + ValueCheckFail(ClassicType, ConstValue), } lazy_static! { @@ -116,14 +115,10 @@ fn map_vals( } /// Typecheck a constant value -pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstTypeError> { +pub(super) fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstTypeError> { match (typ, val) { - (ClassicType::Hashable(HashableType::Int(exp_width)), ConstValue::Int { value, width }) => { - if exp_width == width { - check_int_fits_in_width(*value, *width).map_err(ConstTypeError::Int) - } else { - Err(ConstTypeError::IntWidthMismatch(*exp_width, *width)) - } + (ClassicType::Hashable(HashableType::Int(exp_width)), ConstValue::Int(value)) => { + check_int_fits_in_width(*value, *exp_width).map_err(ConstTypeError::Int) } (ClassicType::F64, ConstValue::F64(_)) => Ok(()), (ty @ ClassicType::Container(c), tm) => match (c, tm) { @@ -136,25 +131,15 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT } Ok(()) } - (Container::Tuple(_), _) => { - Err(ConstTypeError::TypeMismatch(ty.clone(), tm.const_type())) - } - (Container::Sum(row), ConstValue::Sum { tag, variants, val }) => { - if tag > &row.len() { - return Err(ConstTypeError::InvalidSumTag); - } - if **row != *variants { - return Err(ConstTypeError::TypeRowMismatch( - *row.clone(), - variants.clone(), - )); + (Container::Tuple(_), _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), tm.clone())), + (Container::Sum(row), ConstValue::Sum(tag, val)) => { + if let Some(ty) = row.get(*tag) { + typecheck_const(ty, val.as_ref()) + } else { + Err(ConstTypeError::InvalidSumTag) } - let ty = variants.get(*tag).unwrap(); - typecheck_const(ty, val.as_ref()) - } - (Container::Sum(_), _) => { - Err(ConstTypeError::TypeMismatch(ty.clone(), tm.const_type())) } + (Container::Sum(_), _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), tm.clone())), (Container::Opaque(ty), ConstValue::Opaque((ty_act, _val))) => { if ty_act != ty { return Err(ConstTypeError::TypeMismatch( @@ -181,7 +166,7 @@ pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstT (ClassicType::Hashable(HashableType::Variable(_)), _) => { Err(ConstTypeError::ConstCantBeVar) } - (ty, _) => Err(ConstTypeError::TypeMismatch(ty.clone(), val.const_type())), + (ty, _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), val.clone())), } } @@ -196,39 +181,35 @@ mod test { #[test] fn test_typecheck_const() { const INT: ClassicType = ClassicType::int::<64>(); - typecheck_const(&INT, &ConstValue::i64(3)).unwrap(); - assert_eq!( - typecheck_const(&HashableType::Int(32).into(), &ConstValue::i64(3)), - Err(ConstTypeError::IntWidthMismatch(32, 64)) - ); + typecheck_const(&INT, &ConstValue::Int(3)).unwrap(); typecheck_const(&ClassicType::F64, &ConstValue::F64(17.4)).unwrap(); assert_eq!( - typecheck_const(&ClassicType::F64, &ConstValue::i64(5)), - Err(ConstTypeError::TypeMismatch( + typecheck_const(&ClassicType::F64, &ConstValue::Int(5)), + Err(ConstTypeError::ValueCheckFail( ClassicType::F64, - ClassicType::i64() + ConstValue::Int(5) )) ); let tuple_ty = ClassicType::new_tuple(classic_row![INT, ClassicType::F64,]); typecheck_const( &tuple_ty, - &ConstValue::Tuple(vec![ConstValue::i64(7), ConstValue::F64(5.1)]), + &ConstValue::Tuple(vec![ConstValue::Int(7), ConstValue::F64(5.1)]), ) .unwrap(); assert_matches!( typecheck_const( &tuple_ty, - &ConstValue::Tuple(vec![ConstValue::F64(4.8), ConstValue::i64(2)]) + &ConstValue::Tuple(vec![ConstValue::F64(4.8), ConstValue::Int(2)]) ), - Err(ConstTypeError::TypeMismatch(_, _)) + Err(ConstTypeError::ValueCheckFail(_, _)) ); assert_eq!( typecheck_const( &tuple_ty, &ConstValue::Tuple(vec![ - ConstValue::i64(5), + ConstValue::Int(5), ConstValue::F64(3.3), - ConstValue::i64(2) + ConstValue::Int(2) ]) ), Err(ConstTypeError::TupleWrongLength) diff --git a/src/ops/handle.rs b/src/ops/handle.rs index c73d7df19..753870f53 100644 --- a/src/ops/handle.rs +++ b/src/ops/handle.rs @@ -99,14 +99,7 @@ impl AliasID { #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. -pub struct ConstID(Node, ClassicType); - -impl ConstID { - /// Return the type of the constant. - pub fn const_type(&self) -> ClassicType { - self.1.clone() - } -} +pub struct ConstID(Node); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [BasicBlock](crate::ops::BasicBlock) node. diff --git a/src/types/type_param.rs b/src/types/type_param.rs index ca97178da..d3c37aae8 100644 --- a/src/types/type_param.rs +++ b/src/types/type_param.rs @@ -6,7 +6,7 @@ use thiserror::Error; -use crate::hugr::typecheck::{check_int_fits_in_width, ConstIntError}; +use crate::ops::constant::typecheck::{check_int_fits_in_width, ConstIntError}; use crate::ops::constant::HugrIntValueStore; use super::{simple::Container, ClassicType, HashableType, PrimType, SimpleType, TypeTag};