diff --git a/jets-bench/benches/elements/data_structures.rs b/jets-bench/benches/elements/data_structures.rs index dd857578..f0c9ffba 100644 --- a/jets-bench/benches/elements/data_structures.rs +++ b/jets-bench/benches/elements/data_structures.rs @@ -4,9 +4,8 @@ use bitcoin::secp256k1; use elements::Txid; use rand::{thread_rng, RngCore}; -pub use simplicity::hashes::sha256; use simplicity::{ - bitcoin, elements, hashes::Hash, hex::FromHex, types::Type, BitIter, Error, Value, + bitcoin, elements, hashes::Hash, hex::FromHex, types::{self, Type}, BitIter, Error, Value, }; use std::sync::Arc; @@ -57,7 +56,8 @@ pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result, Erro assert!(n < 16); assert!(v.len() < (1 << (n + 1))); let mut iter = BitIter::new(v.iter().copied()); - let types = Type::powers_of_two(n); // size n + 1 + let ctx = types::Context::new(); + let types = Type::powers_of_two(&ctx, n); // size n + 1 let mut res = None; while n > 0 { let v = if v.len() >= (1 << (n + 1)) { diff --git a/jets-bench/benches/elements/main.rs b/jets-bench/benches/elements/main.rs index efc3f6e4..25761d5c 100644 --- a/jets-bench/benches/elements/main.rs +++ b/jets-bench/benches/elements/main.rs @@ -93,8 +93,8 @@ impl ElementsBenchEnvType { } fn jet_arrow(jet: Elements) -> (Arc, Arc) { - let src_ty = jet.source_ty().to_type().final_data().unwrap(); - let tgt_ty = jet.target_ty().to_type().final_data().unwrap(); + let src_ty = jet.source_ty().to_final(); + let tgt_ty = jet.target_ty().to_final(); (src_ty, tgt_ty) } @@ -302,7 +302,7 @@ fn bench(c: &mut Criterion) { let keypair = bitcoin::key::Keypair::new(&secp_ctx, &mut thread_rng()); let xpk = bitcoin::key::XOnlyPublicKey::from_keypair(&keypair); - let msg = bitcoin::secp256k1::Message::from_slice(&rand::random::<[u8; 32]>()).unwrap(); + let msg = bitcoin::secp256k1::Message::from_digest_slice(&rand::random::<[u8; 32]>()).unwrap(); let sig = secp_ctx.sign_schnorr(&msg, &keypair); let xpk_value = Value::u256_from_slice(&xpk.0.serialize()); let sig_value = Value::u512_from_slice(sig.as_ref()); diff --git a/src/bit_encoding/bitwriter.rs b/src/bit_encoding/bitwriter.rs index faae4f82..0a6a3be8 100644 --- a/src/bit_encoding/bitwriter.rs +++ b/src/bit_encoding/bitwriter.rs @@ -117,12 +117,13 @@ mod tests { use super::*; use crate::jet::Core; use crate::node::CoreConstructible; + use crate::types; use crate::ConstructNode; use std::sync::Arc; #[test] fn vec() { - let program = Arc::>::unit(); + let program = Arc::>::unit(&types::Context::new()); let _ = write_to_vec(|w| program.encode(w)); } diff --git a/src/bit_encoding/decode.rs b/src/bit_encoding/decode.rs index 8cd9cf96..d4999d0b 100644 --- a/src/bit_encoding/decode.rs +++ b/src/bit_encoding/decode.rs @@ -12,6 +12,7 @@ use crate::node::{ ConstructNode, CoreConstructible, DisconnectConstructible, JetConstructible, NoWitness, WitnessConstructible, }; +use crate::types; use crate::{BitIter, FailEntropy, Value}; use std::collections::HashSet; use std::sync::Arc; @@ -178,6 +179,7 @@ pub fn decode_expression, J: Jet>( return Err(Error::TooManyNodes(len)); } + let inference_context = types::Context::new(); let mut nodes = Vec::with_capacity(len); for _ in 0..len { let new_node = decode_node(bits, nodes.len())?; @@ -195,8 +197,8 @@ pub fn decode_expression, J: Jet>( } let new = match nodes[data.node.0] { - DecodeNode::Unit => Node(ArcNode::unit()), - DecodeNode::Iden => Node(ArcNode::iden()), + DecodeNode::Unit => Node(ArcNode::unit(&inference_context)), + DecodeNode::Iden => Node(ArcNode::iden(&inference_context)), DecodeNode::InjL(i) => Node(ArcNode::injl(converted[i].get()?)), DecodeNode::InjR(i) => Node(ArcNode::injr(converted[i].get()?)), DecodeNode::Take(i) => Node(ArcNode::take(converted[i].get()?)), @@ -222,16 +224,16 @@ pub fn decode_expression, J: Jet>( converted[i].get()?, &Some(Arc::clone(converted[j].get()?)), )?), - DecodeNode::Witness => Node(ArcNode::witness(NoWitness)), - DecodeNode::Fail(entropy) => Node(ArcNode::fail(entropy)), + DecodeNode::Witness => Node(ArcNode::witness(&inference_context, NoWitness)), + DecodeNode::Fail(entropy) => Node(ArcNode::fail(&inference_context, entropy)), DecodeNode::Hidden(cmr) => { if !hidden_set.insert(cmr) { return Err(Error::SharingNotMaximal); } Hidden(cmr) } - DecodeNode::Jet(j) => Node(ArcNode::jet(j)), - DecodeNode::Word(ref w) => Node(ArcNode::const_word(Arc::clone(w))), + DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)), + DecodeNode::Word(ref w) => Node(ArcNode::const_word(&inference_context, Arc::clone(w))), }; converted.push(new); } diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index fec35c17..8aeccec2 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -9,7 +9,7 @@ use crate::node::{ self, Commit, CommitData, CommitNode, Converter, Inner, NoDisconnect, NoWitness, Node, Witness, WitnessData, }; -use crate::node::{Construct, ConstructData, Constructible}; +use crate::node::{Construct, ConstructData, Constructible as _, CoreConstructible as _}; use crate::types; use crate::types::arrow::{Arrow, FinalArrow}; use crate::{encode, Value, WitnessNode}; @@ -113,11 +113,12 @@ impl NamedCommitNode { witness: &HashMap, Arc>, disconnect: &HashMap, Arc>>, ) -> Arc> { - struct Populator<'a, J: Jet>( - &'a HashMap, Arc>, - &'a HashMap, Arc>>, - PhantomData, - ); + struct Populator<'a, J: Jet> { + witness_map: &'a HashMap, Arc>, + disconnect_map: &'a HashMap, Arc>>, + inference_context: types::Context, + phantom: PhantomData, + } impl<'a, J: Jet> Converter>, Witness> for Populator<'a, J> { type Error = (); @@ -133,7 +134,7 @@ impl NamedCommitNode { // Which nodes are pruned is not known when this code is executed. // If an unpruned node is unpopulated, then there will be an error // during the finalization. - Ok(self.0.get(name).cloned()) + Ok(self.witness_map.get(name).cloned()) } fn convert_disconnect( @@ -152,16 +153,17 @@ impl NamedCommitNode { // We keep the missing disconnected branches empty. // Like witness nodes (see above), disconnect nodes may be pruned later. // The finalization will detect missing branches and throw an error. - let maybe_commit = self.1.get(hole_name); - // FIXME: Recursive call of to_witness_node - // We cannot introduce a stack - // because we are implementing methods of the trait Converter - // which are used Marker::convert(). + let maybe_commit = self.disconnect_map.get(hole_name); + // FIXME: recursive call to convert + // We cannot introduce a stack because we are implementing the Converter + // trait and do not have access to the actual algorithm used for conversion + // in order to save its state. // // OTOH, if a user writes a program with so many disconnected expressions // that there is a stack overflow, it's his own fault :) - // This would fail in a fuzz test. - let witness = maybe_commit.map(|commit| commit.to_witness_node(self.0, self.1)); + // This may fail in a fuzz test. + let witness = maybe_commit + .map(|commit| commit.convert::(self).unwrap()); Ok(witness) } } @@ -179,12 +181,18 @@ impl NamedCommitNode { let inner = inner .map(|node| node.cached_data()) .map_witness(|maybe_value| maybe_value.clone()); - Ok(WitnessData::from_inner(inner).expect("types are already finalized")) + Ok(WitnessData::from_inner(&self.inference_context, inner) + .expect("types are already finalized")) } } - self.convert::(&mut Populator(witness, disconnect, PhantomData)) - .unwrap() + self.convert::(&mut Populator { + witness_map: witness, + disconnect_map: disconnect, + inference_context: types::Context::new(), + phantom: PhantomData, + }) + .unwrap() } /// Encode a Simplicity expression to bits without any witness data @@ -239,6 +247,7 @@ pub struct NamedConstructData { impl NamedConstructNode { /// Construct a named construct node from parts. pub fn new( + inference_context: &types::Context, name: Arc, position: Position, user_source_types: Arc<[types::Type]>, @@ -246,6 +255,7 @@ impl NamedConstructNode { inner: node::Inner, J, Arc, WitnessOrHole>, ) -> Result { let construct_data = ConstructData::from_inner( + inference_context, inner .as_ref() .map(|data| &data.cached_data().internal) @@ -289,6 +299,11 @@ impl NamedConstructNode { self.cached_data().internal.arrow() } + /// Accessor for the node's type inference context. + pub fn inference_context(&self) -> &types::Context { + self.cached_data().internal.inference_context() + } + /// Finalizes the types of the underlying [`crate::ConstructNode`]. pub fn finalize_types_main(&self) -> Result>, ErrorSet> { self.finalize_types_inner(true) @@ -380,17 +395,23 @@ impl NamedConstructNode { .map_disconnect(|_| &NoDisconnect) .copy_witness(); + let ctx = data.node.inference_context(); + if !self.for_main { // For non-`main` fragments, treat the ascriptions as normative, and apply them // before finalizing the type. let arrow = data.node.arrow(); for ty in data.node.cached_data().user_source_types.as_ref() { - if let Err(e) = arrow.source.unify(ty, "binding source type annotation") { + if let Err(e) = + ctx.unify(&arrow.source, ty, "binding source type annotation") + { self.errors.add(data.node.position(), e); } } for ty in data.node.cached_data().user_target_types.as_ref() { - if let Err(e) = arrow.target.unify(ty, "binding target type annotation") { + if let Err(e) = + ctx.unify(&arrow.target, ty, "binding target type annotation") + { self.errors.add(data.node.position(), e); } } @@ -407,19 +428,19 @@ impl NamedConstructNode { if self.for_main { // For `main`, only apply type ascriptions *after* inference has completely // determined the type. - let source_bound = - types::Bound::Complete(Arc::clone(&commit_data.arrow().source)); - let source_ty = types::Type::from(source_bound); + let source_ty = + types::Type::complete(ctx, Arc::clone(&commit_data.arrow().source)); for ty in data.node.cached_data().user_source_types.as_ref() { - if let Err(e) = source_ty.unify(ty, "binding source type annotation") { + if let Err(e) = ctx.unify(&source_ty, ty, "binding source type annotation") + { self.errors.add(data.node.position(), e); } } - let target_bound = - types::Bound::Complete(Arc::clone(&commit_data.arrow().target)); - let target_ty = types::Type::from(target_bound); + let target_ty = + types::Type::complete(ctx, Arc::clone(&commit_data.arrow().target)); for ty in data.node.cached_data().user_target_types.as_ref() { - if let Err(e) = target_ty.unify(ty, "binding target type annotation") { + if let Err(e) = ctx.unify(&target_ty, ty, "binding target type annotation") + { self.errors.add(data.node.position(), e); } } @@ -440,22 +461,23 @@ impl NamedConstructNode { }; if for_main { - let unit_ty = types::Type::unit(); + let ctx = self.inference_context(); + let unit_ty = types::Type::unit(ctx); if self.cached_data().user_source_types.is_empty() { - if let Err(e) = self - .arrow() - .source - .unify(&unit_ty, "setting root source to unit") - { + if let Err(e) = ctx.unify( + &self.arrow().source, + &unit_ty, + "setting root source to unit", + ) { finalizer.errors.add(self.position(), e); } } if self.cached_data().user_target_types.is_empty() { - if let Err(e) = self - .arrow() - .target - .unify(&unit_ty, "setting root source to unit") - { + if let Err(e) = ctx.unify( + &self.arrow().target, + &unit_ty, + "setting root target to unit", + ) { finalizer.errors.add(self.position(), e); } } diff --git a/src/human_encoding/parse/ast.rs b/src/human_encoding/parse/ast.rs index 241b606b..d34153e2 100644 --- a/src/human_encoding/parse/ast.rs +++ b/src/human_encoding/parse/ast.rs @@ -82,14 +82,25 @@ pub enum Type { impl Type { /// Convert to a Simplicity type - pub fn reify(self) -> types::Type { + pub fn reify(self, ctx: &types::Context) -> types::Type { match self { - Type::Name(s) => types::Type::free(s), - Type::One => types::Type::unit(), - Type::Two => types::Type::sum(types::Type::unit(), types::Type::unit()), - Type::Product(left, right) => types::Type::product(left.reify(), right.reify()), - Type::Sum(left, right) => types::Type::sum(left.reify(), right.reify()), - Type::TwoTwoN(n) => types::Type::two_two_n(n as usize), // cast OK as we are only using tiny numbers + Type::Name(s) => types::Type::free(ctx, s), + Type::One => types::Type::unit(ctx), + Type::Two => { + let unit_ty = types::Type::unit(ctx); + types::Type::sum(ctx, unit_ty.shallow_clone(), unit_ty) + } + Type::Product(left, right) => { + let left = left.reify(ctx); + let right = right.reify(ctx); + types::Type::product(ctx, left, right) + } + Type::Sum(left, right) => { + let left = left.reify(ctx); + let right = right.reify(ctx); + types::Type::sum(ctx, left, right) + } + Type::TwoTwoN(n) => types::Type::two_two_n(ctx, n as usize), // cast OK as we are only using tiny numbers } } } @@ -633,9 +644,7 @@ fn grammar() -> Grammar> { Error::BadWordLength { bit_length }, )); } - let ty = types::Type::two_two_n(bit_length.trailing_zeros() as usize) - .final_data() - .unwrap(); + let ty = types::Final::two_two_n(bit_length.trailing_zeros() as usize); // unwrap ok here since literally every sequence of bits is a valid // value for the given type let value = iter.read_value(&ty).unwrap(); diff --git a/src/human_encoding/parse/mod.rs b/src/human_encoding/parse/mod.rs index 74940cf5..21a4e7d5 100644 --- a/src/human_encoding/parse/mod.rs +++ b/src/human_encoding/parse/mod.rs @@ -7,7 +7,7 @@ mod ast; use crate::dag::{Dag, DagLike, InternalSharing}; use crate::jet::Jet; use crate::node; -use crate::types::Type; +use crate::types::{self, Type}; use std::collections::HashMap; use std::mem; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -181,6 +181,7 @@ pub fn parse( program: &str, ) -> Result, Arc>>, ErrorSet> { let mut errors = ErrorSet::new(); + let inference_context = types::Context::new(); // ** // Step 1: Read expressions into HashMap, checking for dupes and illegal names. // ** @@ -205,10 +206,10 @@ pub fn parse( } } if let Some(ty) = line.arrow.0 { - entry.add_source_type(ty.reify()); + entry.add_source_type(ty.reify(&inference_context)); } if let Some(ty) = line.arrow.1 { - entry.add_target_type(ty.reify()); + entry.add_target_type(ty.reify(&inference_context)); } } @@ -485,6 +486,7 @@ pub fn parse( .unwrap_or_else(|| Arc::from(namer.assign_name(inner.as_ref()).as_str())); let node = NamedConstructNode::new( + &inference_context, Arc::clone(&name), data.node.position, Arc::clone(&data.node.user_source_types), diff --git a/src/jet/elements/tests.rs b/src/jet/elements/tests.rs index b852edef..d1bbad9f 100644 --- a/src/jet/elements/tests.rs +++ b/src/jet/elements/tests.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use crate::jet::elements::{ElementsEnv, ElementsUtxo}; use crate::jet::Elements; use crate::node::{ConstructNode, JetConstructible}; +use crate::types; use crate::{BitMachine, Cmr, Value}; use elements::secp256k1_zkp::Tweak; use elements::taproot::ControlBlock; @@ -99,7 +100,7 @@ fn test_ffi_env() { BlockHash::all_zeros(), ); - let prog = Arc::>::jet(Elements::LockTime); + let prog = Arc::>::jet(&types::Context::new(), Elements::LockTime); assert_eq!( BitMachine::test_exec(prog, &env).expect("executing"), Value::u32(100), diff --git a/src/jet/mod.rs b/src/jet/mod.rs index 556492e3..8d78c202 100644 --- a/src/jet/mod.rs +++ b/src/jet/mod.rs @@ -93,18 +93,20 @@ pub trait Jet: mod tests { use crate::jet::Core; use crate::node::{ConstructNode, CoreConstructible, JetConstructible}; + use crate::types; use crate::{BitMachine, Value}; use std::sync::Arc; #[test] fn test_ffi_jet() { + let ctx = types::Context::new(); let two_words = Arc::>::comp( &Arc::>::pair( - &Arc::>::const_word(Value::u32(2)), - &Arc::>::const_word(Value::u32(16)), + &Arc::>::const_word(&ctx, Value::u32(2)), + &Arc::>::const_word(&ctx, Value::u32(16)), ) .unwrap(), - &Arc::>::jet(Core::Add32), + &Arc::>::jet(&ctx, Core::Add32), ) .unwrap(); assert_eq!( @@ -118,9 +120,10 @@ mod tests { #[test] fn test_simple() { + let ctx = types::Context::new(); let two_words = Arc::>::pair( - &Arc::>::const_word(Value::u32(2)), - &Arc::>::const_word(Value::u16(16)), + &Arc::>::const_word(&ctx, Value::u32(2)), + &Arc::>::const_word(&ctx, Value::u16(16)), ) .unwrap(); assert_eq!( diff --git a/src/jet/type_name.rs b/src/jet/type_name.rs index 3a2e49a7..092f8222 100644 --- a/src/jet/type_name.rs +++ b/src/jet/type_name.rs @@ -4,7 +4,7 @@ //! //! Source and target types of jet nodes need to be specified manually. -use crate::types::{Final, Type}; +use crate::types::{self, Final, Type}; use std::cmp; use std::sync::Arc; @@ -30,94 +30,68 @@ use std::sync::Arc; #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub struct TypeName(pub &'static [u8]); -trait TypeConstructible { - fn two_two_n(n: Option) -> Self; - fn sum(left: Self, right: Self) -> Self; - fn product(left: Self, right: Self) -> Self; -} - -impl TypeConstructible for Type { - fn two_two_n(n: Option) -> Self { - match n { - None => Type::unit(), - Some(m) => Type::two_two_n(m as usize), // cast safety: 32-bit arch or higher - } +impl TypeName { + /// Convert the type name into a type. + pub fn to_type(&self, ctx: &types::Context) -> Type { + Type::complete(ctx, self.to_final()) } - fn sum(left: Self, right: Self) -> Self { - Type::sum(left, right) - } + /// Convert the type name into a finalized type. + pub fn to_final(&self) -> Arc { + let mut stack = Vec::with_capacity(16); - fn product(left: Self, right: Self) -> Self { - Type::product(left, right) - } -} + for c in self.0.iter().rev() { + match c { + b'1' => stack.push(Final::unit()), + b'2' => stack.push(Final::two_two_n(0)), + b'c' => stack.push(Final::two_two_n(3)), + b's' => stack.push(Final::two_two_n(4)), + b'i' => stack.push(Final::two_two_n(5)), + b'l' => stack.push(Final::two_two_n(6)), + b'h' => stack.push(Final::two_two_n(8)), + b'+' | b'*' => { + let left = stack.pop().expect("Illegal type name syntax!"); + let right = stack.pop().expect("Illegal type name syntax!"); -impl TypeConstructible for Arc { - fn two_two_n(n: Option) -> Self { - match n { - None => Final::unit(), - Some(m) => Final::two_two_n(m as usize), // cast safety: 32-bit arch or higher + match c { + b'+' => stack.push(Final::sum(left, right)), + b'*' => stack.push(Final::product(left, right)), + _ => unreachable!(), + } + } + _ => panic!("Illegal type name syntax!"), + } } - } - - fn sum(left: Self, right: Self) -> Self { - Final::sum(left, right) - } - - fn product(left: Self, right: Self) -> Self { - Final::product(left, right) - } -} -struct BitWidth(usize); - -impl TypeConstructible for BitWidth { - fn two_two_n(n: Option) -> Self { - match n { - None => BitWidth(0), - Some(m) => BitWidth(usize::pow(2, m)), + if stack.len() == 1 { + stack.pop().unwrap() + } else { + panic!("Illegal type name syntax!") } } - fn sum(left: Self, right: Self) -> Self { - BitWidth(1 + cmp::max(left.0, right.0)) - } - - fn product(left: Self, right: Self) -> Self { - BitWidth(left.0 + right.0) - } -} - -impl TypeName { - // b'1' = 49 - // b'2' = 50 - // b'c' = 99 - // b's' = 115 - // b'i' = 105 - // b'l' = 108 - // b'h' = 104 - // b'+' = 43 - // b'*' = 42 - fn construct(&self) -> T { + /// Convert the type name into a type's bitwidth. + /// + /// This is more efficient than creating the type and computing its bit-width + pub fn to_bit_width(&self) -> usize { let mut stack = Vec::with_capacity(16); for c in self.0.iter().rev() { match c { - b'1' => stack.push(T::two_two_n(None)), - b'2' => stack.push(T::two_two_n(Some(0))), - b'c' => stack.push(T::two_two_n(Some(3))), - b's' => stack.push(T::two_two_n(Some(4))), - b'i' => stack.push(T::two_two_n(Some(5))), - b'l' => stack.push(T::two_two_n(Some(6))), - b'h' => stack.push(T::two_two_n(Some(8))), + b'1' => stack.push(0), + b'2' => stack.push(1), + b'c' => stack.push(8), + b's' => stack.push(16), + b'i' => stack.push(32), + b'l' => stack.push(64), + b'h' => stack.push(256), b'+' | b'*' => { let left = stack.pop().expect("Illegal type name syntax!"); let right = stack.pop().expect("Illegal type name syntax!"); match c { - b'+' => stack.push(T::sum(left, right)), - b'*' => stack.push(T::product(left, right)), + b'+' => stack.push(1 + cmp::max(left, right)), + b'*' => stack.push(left + right), _ => unreachable!(), } } @@ -131,21 +105,4 @@ impl TypeName { panic!("Illegal type name syntax!") } } - - /// Convert the type name into a type. - pub fn to_type(&self) -> Type { - self.construct() - } - - /// Convert the type name into a finalized type. - pub fn to_final(&self) -> Arc { - self.construct() - } - - /// Convert the type name into a type's bitwidth. - /// - /// This is more efficient than creating the type and computing its bit-width - pub fn to_bit_width(&self) -> usize { - self.construct::().0 - } } diff --git a/src/merkle/amr.rs b/src/merkle/amr.rs index 10bb3ab2..e6d349ae 100644 --- a/src/merkle/amr.rs +++ b/src/merkle/amr.rs @@ -291,11 +291,13 @@ mod tests { use crate::jet::Core; use crate::node::{ConstructNode, JetConstructible}; + use crate::types; use std::sync::Arc; #[test] fn fixed_amr() { - let node = Arc::>::jet(Core::Verify) + let ctx = types::Context::new(); + let node = Arc::>::jet(&ctx, Core::Verify) .finalize_types_non_program() .unwrap(); // Checked against C implementation diff --git a/src/merkle/cmr.rs b/src/merkle/cmr.rs index 8c8e421c..9093ac18 100644 --- a/src/merkle/cmr.rs +++ b/src/merkle/cmr.rs @@ -7,7 +7,7 @@ use crate::jet::Jet; use crate::node::{ CoreConstructible, DisconnectConstructible, JetConstructible, WitnessConstructible, }; -use crate::types::Error; +use crate::types::{self, Error}; use crate::{FailEntropy, Tmr, Value}; use hashes::sha256::Midstate; @@ -253,75 +253,142 @@ impl Cmr { ]; } -impl CoreConstructible for Cmr { - fn iden() -> Self { - Cmr::iden() +/// Wrapper around a CMR which allows it to be constructed with the +/// `*Constructible*` traits, allowing CMRs to be computed using the +/// same generic construction code that nodes are. +pub struct ConstructibleCmr { + pub cmr: Cmr, + pub inference_context: types::Context, +} + +impl CoreConstructible for ConstructibleCmr { + fn iden(inference_context: &types::Context) -> Self { + ConstructibleCmr { + cmr: Cmr::iden(), + inference_context: inference_context.shallow_clone(), + } } - fn unit() -> Self { - Cmr::unit() + fn unit(inference_context: &types::Context) -> Self { + ConstructibleCmr { + cmr: Cmr::unit(), + inference_context: inference_context.shallow_clone(), + } } fn injl(child: &Self) -> Self { - Cmr::injl(*child) + ConstructibleCmr { + cmr: Cmr::injl(child.cmr), + inference_context: child.inference_context.shallow_clone(), + } } fn injr(child: &Self) -> Self { - Cmr::injl(*child) + ConstructibleCmr { + cmr: Cmr::injl(child.cmr), + inference_context: child.inference_context.shallow_clone(), + } } fn take(child: &Self) -> Self { - Cmr::take(*child) + ConstructibleCmr { + cmr: Cmr::take(child.cmr), + inference_context: child.inference_context.shallow_clone(), + } } fn drop_(child: &Self) -> Self { - Cmr::drop(*child) + ConstructibleCmr { + cmr: Cmr::drop(child.cmr), + inference_context: child.inference_context.shallow_clone(), + } } fn comp(left: &Self, right: &Self) -> Result { - Ok(Cmr::comp(*left, *right)) + left.inference_context.check_eq(&right.inference_context)?; + Ok(ConstructibleCmr { + cmr: Cmr::comp(left.cmr, right.cmr), + inference_context: left.inference_context.shallow_clone(), + }) } fn case(left: &Self, right: &Self) -> Result { - Ok(Cmr::case(*left, *right)) + left.inference_context.check_eq(&right.inference_context)?; + Ok(ConstructibleCmr { + cmr: Cmr::case(left.cmr, right.cmr), + inference_context: left.inference_context.shallow_clone(), + }) } fn assertl(left: &Self, right: Cmr) -> Result { - Ok(Cmr::case(*left, right)) + Ok(ConstructibleCmr { + cmr: Cmr::case(left.cmr, right), + inference_context: left.inference_context.shallow_clone(), + }) } fn assertr(left: Cmr, right: &Self) -> Result { - Ok(Cmr::case(left, *right)) + Ok(ConstructibleCmr { + cmr: Cmr::case(left, right.cmr), + inference_context: right.inference_context.shallow_clone(), + }) } fn pair(left: &Self, right: &Self) -> Result { - Ok(Cmr::pair(*left, *right)) + left.inference_context.check_eq(&right.inference_context)?; + Ok(ConstructibleCmr { + cmr: Cmr::pair(left.cmr, right.cmr), + inference_context: left.inference_context.shallow_clone(), + }) } - fn fail(entropy: FailEntropy) -> Self { - Cmr::fail(entropy) + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { + ConstructibleCmr { + cmr: Cmr::fail(entropy), + inference_context: inference_context.shallow_clone(), + } + } + + fn const_word(inference_context: &types::Context, word: Arc) -> Self { + ConstructibleCmr { + cmr: Cmr::const_word(&word), + inference_context: inference_context.shallow_clone(), + } } - fn const_word(word: Arc) -> Self { - Cmr::const_word(&word) + fn inference_context(&self) -> &types::Context { + &self.inference_context } } -impl DisconnectConstructible for Cmr { +impl DisconnectConstructible for ConstructibleCmr { + // Specifically with disconnect we don't check for consistency between the + // type inference context of the disconnected node, if any, and that of + // the left node. The idea is, from the point of view of (Constructible)Cmr, + // the right child of disconnect doesn't even exist. fn disconnect(left: &Self, _right: &X) -> Result { - Ok(Cmr::disconnect(*left)) + Ok(ConstructibleCmr { + cmr: Cmr::disconnect(left.cmr), + inference_context: left.inference_context.shallow_clone(), + }) } } -impl WitnessConstructible for Cmr { - fn witness(_witness: W) -> Self { - Cmr::witness() +impl WitnessConstructible for ConstructibleCmr { + fn witness(inference_context: &types::Context, _witness: W) -> Self { + ConstructibleCmr { + cmr: Cmr::witness(), + inference_context: inference_context.shallow_clone(), + } } } -impl JetConstructible for Cmr { - fn jet(jet: J) -> Self { - jet.cmr() +impl JetConstructible for ConstructibleCmr { + fn jet(inference_context: &types::Context, jet: J) -> Self { + ConstructibleCmr { + cmr: jet.cmr(), + inference_context: inference_context.shallow_clone(), + } } } @@ -337,7 +404,8 @@ mod tests { #[test] fn cmr_display_unit() { - let c = Arc::>::unit(); + let ctx = types::Context::new(); + let c = Arc::>::unit(&ctx); assert_eq!( c.cmr().to_string(), @@ -364,7 +432,8 @@ mod tests { #[test] fn bit_cmr() { - let unit = Arc::>::unit(); + let ctx = types::Context::new(); + let unit = Arc::>::unit(&ctx); let bit0 = Arc::>::injl(&unit); assert_eq!(bit0.cmr(), Cmr::BITS[0]); diff --git a/src/merkle/tmr.rs b/src/merkle/tmr.rs index f36d6268..1c468f69 100644 --- a/src/merkle/tmr.rs +++ b/src/merkle/tmr.rs @@ -257,9 +257,11 @@ impl Tmr { #[cfg(test)] mod tests { - use super::super::bip340_iv; use super::*; + use crate::merkle::bip340_iv; + use crate::types; + #[test] fn const_ivs() { assert_eq!( @@ -280,7 +282,7 @@ mod tests { #[allow(clippy::needless_range_loop)] fn const_powers_of_2() { let n = Tmr::POWERS_OF_TWO.len(); - let types = crate::types::Type::powers_of_two(n); + let types = crate::types::Type::powers_of_two(&types::Context::new(), n); for i in 0..n { assert_eq!(Some(Tmr::POWERS_OF_TWO[i]), types[i].tmr()); } diff --git a/src/node/commit.rs b/src/node/commit.rs index a2dc81c2..ab7ee71a 100644 --- a/src/node/commit.rs +++ b/src/node/commit.rs @@ -202,7 +202,10 @@ impl CommitNode { /// Convert a [`CommitNode`] back to a [`ConstructNode`] by redoing type inference pub fn unfinalize_types(&self) -> Result>, types::Error> { - struct UnfinalizeTypes(PhantomData); + struct UnfinalizeTypes { + inference_context: types::Context, + phantom: PhantomData, + } impl Converter, Construct> for UnfinalizeTypes { type Error = types::Error; @@ -232,11 +235,17 @@ impl CommitNode { .map(|node| node.arrow()) .map_disconnect(|maybe_node| maybe_node.as_ref().map(|node| node.arrow())); let inner = inner.disconnect_as_ref(); // lol sigh rust - Ok(ConstructData::new(Arrow::from_inner(inner)?)) + Ok(ConstructData::new(Arrow::from_inner( + &self.inference_context, + inner, + )?)) } } - self.convert::>, _, _>(&mut UnfinalizeTypes(PhantomData)) + self.convert::>, _, _>(&mut UnfinalizeTypes { + inference_context: types::Context::new(), + phantom: PhantomData, + }) } /// Decode a Simplicity program from bits, without witness data. diff --git a/src/node/construct.rs b/src/node/construct.rs index 29999c36..2b57de20 100644 --- a/src/node/construct.rs +++ b/src/node/construct.rs @@ -52,13 +52,18 @@ impl ConstructNode { /// Sets the source and target type of the node to unit pub fn set_arrow_to_program(&self) -> Result<(), types::Error> { - let unit_ty = types::Type::unit(); - self.arrow() - .source - .unify(&unit_ty, "setting root source to unit")?; - self.arrow() - .target - .unify(&unit_ty, "setting root target to unit")?; + let ctx = self.data.inference_context(); + let unit_ty = types::Type::unit(ctx); + ctx.unify( + &self.arrow().source, + &unit_ty, + "setting root source to unit", + )?; + ctx.unify( + &self.arrow().target, + &unit_ty, + "setting root target to unit", + )?; Ok(()) } @@ -165,16 +170,16 @@ impl ConstructData { } impl CoreConstructible for ConstructData { - fn iden() -> Self { + fn iden(inference_context: &types::Context) -> Self { ConstructData { - arrow: Arrow::iden(), + arrow: Arrow::iden(inference_context), phantom: PhantomData, } } - fn unit() -> Self { + fn unit(inference_context: &types::Context) -> Self { ConstructData { - arrow: Arrow::unit(), + arrow: Arrow::unit(inference_context), phantom: PhantomData, } } @@ -242,19 +247,23 @@ impl CoreConstructible for ConstructData { }) } - fn fail(entropy: FailEntropy) -> Self { + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { ConstructData { - arrow: Arrow::fail(entropy), + arrow: Arrow::fail(inference_context, entropy), phantom: PhantomData, } } - fn const_word(word: Arc) -> Self { + fn const_word(inference_context: &types::Context, word: Arc) -> Self { ConstructData { - arrow: Arrow::const_word(word), + arrow: Arrow::const_word(inference_context, word), phantom: PhantomData, } } + + fn inference_context(&self) -> &types::Context { + self.arrow.inference_context() + } } impl DisconnectConstructible>>> for ConstructData { @@ -271,18 +280,18 @@ impl DisconnectConstructible>>> for Construc } impl WitnessConstructible for ConstructData { - fn witness(witness: NoWitness) -> Self { + fn witness(inference_context: &types::Context, witness: NoWitness) -> Self { ConstructData { - arrow: Arrow::witness(witness), + arrow: Arrow::witness(inference_context, witness), phantom: PhantomData, } } } impl JetConstructible for ConstructData { - fn jet(jet: J) -> Self { + fn jet(inference_context: &types::Context, jet: J) -> Self { ConstructData { - arrow: Arrow::jet(jet), + arrow: Arrow::jet(inference_context, jet), phantom: PhantomData, } } @@ -295,7 +304,8 @@ mod tests { #[test] fn occurs_check_error() { - let iden = Arc::>::iden(); + let ctx = types::Context::new(); + let iden = Arc::>::iden(&ctx); let node = Arc::>::disconnect(&iden, &Some(Arc::clone(&iden))).unwrap(); assert!(matches!( @@ -306,8 +316,9 @@ mod tests { #[test] fn occurs_check_2() { + let ctx = types::Context::new(); // A more complicated occurs-check test that caused a deadlock in the past. - let iden = Arc::>::iden(); + let iden = Arc::>::iden(&ctx); let injr = Arc::>::injr(&iden); let pair = Arc::>::pair(&injr, &iden).unwrap(); let drop = Arc::>::drop_(&pair); @@ -326,8 +337,9 @@ mod tests { #[test] fn occurs_check_3() { + let ctx = types::Context::new(); // A similar example that caused a slightly different deadlock in the past. - let wit = Arc::>::witness(NoWitness); + let wit = Arc::>::witness(&ctx, NoWitness); let drop = Arc::>::drop_(&wit); let comp1 = Arc::>::comp(&drop, &drop).unwrap(); @@ -353,7 +365,8 @@ mod tests { #[test] fn type_check_error() { - let unit = Arc::>::unit(); + let ctx = types::Context::new(); + let unit = Arc::>::unit(&ctx); let case = Arc::>::case(&unit, &unit).unwrap(); assert!(matches!( @@ -364,26 +377,30 @@ mod tests { #[test] fn scribe() { - let unit = Arc::>::unit(); + // Ok to use same type inference context for all the below tests, + // since everything has concrete types and anyway we only care + // about CMRs, for which type inference is irrelevant. + let ctx = types::Context::new(); + let unit = Arc::>::unit(&ctx); let bit0 = Arc::>::injl(&unit); let bit1 = Arc::>::injr(&unit); let bits01 = Arc::>::pair(&bit0, &bit1).unwrap(); assert_eq!( unit.cmr(), - Arc::>::scribe(&Value::Unit).cmr() + Arc::>::scribe(&ctx, &Value::Unit).cmr() ); assert_eq!( bit0.cmr(), - Arc::>::scribe(&Value::u1(0)).cmr() + Arc::>::scribe(&ctx, &Value::u1(0)).cmr() ); assert_eq!( bit1.cmr(), - Arc::>::scribe(&Value::u1(1)).cmr() + Arc::>::scribe(&ctx, &Value::u1(1)).cmr() ); assert_eq!( bits01.cmr(), - Arc::>::scribe(&Value::u2(1)).cmr() + Arc::>::scribe(&ctx, &Value::u2(1)).cmr() ); } } diff --git a/src/node/mod.rs b/src/node/mod.rs index c0b7b943..f057c1a6 100644 --- a/src/node/mod.rs +++ b/src/node/mod.rs @@ -130,10 +130,13 @@ pub trait Constructible: + CoreConstructible + Sized { - fn from_inner(inner: Inner<&Self, J, &X, W>) -> Result { + fn from_inner( + inference_context: &types::Context, + inner: Inner<&Self, J, &X, W>, + ) -> Result { match inner { - Inner::Iden => Ok(Self::iden()), - Inner::Unit => Ok(Self::unit()), + Inner::Iden => Ok(Self::iden(inference_context)), + Inner::Unit => Ok(Self::unit(inference_context)), Inner::InjL(child) => Ok(Self::injl(child)), Inner::InjR(child) => Ok(Self::injr(child)), Inner::Take(child) => Ok(Self::take(child)), @@ -144,10 +147,10 @@ pub trait Constructible: Inner::AssertR(l_cmr, right) => Self::assertr(l_cmr, right), Inner::Pair(left, right) => Self::pair(left, right), Inner::Disconnect(left, right) => Self::disconnect(left, right), - Inner::Fail(entropy) => Ok(Self::fail(entropy)), - Inner::Word(ref w) => Ok(Self::const_word(Arc::clone(w))), - Inner::Jet(j) => Ok(Self::jet(j)), - Inner::Witness(w) => Ok(Self::witness(w)), + Inner::Fail(entropy) => Ok(Self::fail(inference_context, entropy)), + Inner::Word(ref w) => Ok(Self::const_word(inference_context, Arc::clone(w))), + Inner::Jet(j) => Ok(Self::jet(inference_context, j)), + Inner::Witness(w) => Ok(Self::witness(inference_context, w)), } } } @@ -162,8 +165,8 @@ impl Constructible for T where } pub trait CoreConstructible: Sized { - fn iden() -> Self; - fn unit() -> Self; + fn iden(inference_context: &types::Context) -> Self; + fn unit(inference_context: &types::Context) -> Self; fn injl(child: &Self) -> Self; fn injr(child: &Self) -> Self; fn take(child: &Self) -> Self; @@ -173,17 +176,20 @@ pub trait CoreConstructible: Sized { fn assertl(left: &Self, right: Cmr) -> Result; fn assertr(left: Cmr, right: &Self) -> Result; fn pair(left: &Self, right: &Self) -> Result; - fn fail(entropy: FailEntropy) -> Self; - fn const_word(word: Arc) -> Self; + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self; + fn const_word(inference_context: &types::Context, word: Arc) -> Self; + + /// Accessor for the type inference context used to create the object. + fn inference_context(&self) -> &types::Context; /// Create a DAG that takes any input and returns `value` as constant output. /// /// _Overall type: A → B where value: B_ - fn scribe(value: &Value) -> Self { + fn scribe(inference_context: &types::Context, value: &Value) -> Self { let mut stack = vec![]; for data in value.post_order_iter::() { match data.node { - Value::Unit => stack.push(Self::unit()), + Value::Unit => stack.push(Self::unit(inference_context)), Value::SumL(..) => { let child = stack.pop().unwrap(); stack.push(Self::injl(&child)); @@ -208,16 +214,16 @@ pub trait CoreConstructible: Sized { /// Create a DAG that takes any input and returns bit `0` as constant output. /// /// _Overall type: A → 2_ - fn bit_false() -> Self { - let unit = Self::unit(); + fn bit_false(inference_context: &types::Context) -> Self { + let unit = Self::unit(inference_context); Self::injl(&unit) } /// Create a DAG that takes any input and returns bit `1` as constant output. /// /// _Overall type: A → 2_ - fn bit_true() -> Self { - let unit = Self::unit(); + fn bit_true(inference_context: &types::Context) -> Self { + let unit = Self::unit(inference_context); Self::injr(&unit) } @@ -241,7 +247,7 @@ pub trait CoreConstructible: Sized { /// /// _Type inference will fail if children are not of the correct type._ fn assert(child: &Self, hash: Cmr) -> Result { - let unit = Self::unit(); + let unit = Self::unit(child.inference_context()); let pair_child_unit = Self::pair(child, &unit)?; let assertr_hidden_unit = Self::assertr(hash, &unit)?; @@ -255,10 +261,10 @@ pub trait CoreConstructible: Sized { /// _Type inference will fail if children are not of the correct type._ #[allow(clippy::should_implement_trait)] fn not(child: &Self) -> Result { - let unit = Self::unit(); + let unit = Self::unit(child.inference_context()); let pair_child_unit = Self::pair(child, &unit)?; - let bit_true = Self::bit_true(); - let bit_false = Self::bit_false(); + let bit_true = Self::bit_true(child.inference_context()); + let bit_false = Self::bit_false(child.inference_context()); let case_true_false = Self::case(&bit_true, &bit_false)?; Self::comp(&pair_child_unit, &case_true_false) @@ -270,9 +276,11 @@ pub trait CoreConstructible: Sized { /// /// _Type inference will fail if children are not of the correct type._ fn and(left: &Self, right: &Self) -> Result { - let iden = Self::iden(); + left.inference_context() + .check_eq(right.inference_context())?; + let iden = Self::iden(left.inference_context()); let pair_left_iden = Self::pair(left, &iden)?; - let bit_false = Self::bit_false(); + let bit_false = Self::bit_false(left.inference_context()); let drop_right = Self::drop_(right); let case_false_right = Self::case(&bit_false, &drop_right)?; @@ -285,10 +293,12 @@ pub trait CoreConstructible: Sized { /// /// _Type inference will fail if children are not of the correct type._ fn or(left: &Self, right: &Self) -> Result { - let iden = Self::iden(); + left.inference_context() + .check_eq(right.inference_context())?; + let iden = Self::iden(left.inference_context()); let pair_left_iden = Self::pair(left, &iden)?; let drop_right = Self::drop_(right); - let bit_true = Self::bit_true(); + let bit_true = Self::bit_true(left.inference_context()); let case_right_true = Self::case(&drop_right, &bit_true)?; Self::comp(&pair_left_iden, &case_right_true) @@ -300,11 +310,11 @@ pub trait DisconnectConstructible: Sized { } pub trait JetConstructible: Sized { - fn jet(jet: J) -> Self; + fn jet(inference_context: &types::Context, jet: J) -> Self; } pub trait WitnessConstructible: Sized { - fn witness(witness: W) -> Self; + fn witness(inference_context: &types::Context, witness: W) -> Self; } /// A node in a Simplicity expression. @@ -373,18 +383,18 @@ where N: Marker, N::CachedData: CoreConstructible, { - fn iden() -> Self { + fn iden(inference_context: &types::Context) -> Self { Arc::new(Node { cmr: Cmr::iden(), - data: N::CachedData::iden(), + data: N::CachedData::iden(inference_context), inner: Inner::Iden, }) } - fn unit() -> Self { + fn unit(inference_context: &types::Context) -> Self { Arc::new(Node { cmr: Cmr::unit(), - data: N::CachedData::unit(), + data: N::CachedData::unit(inference_context), inner: Inner::Unit, }) } @@ -461,21 +471,25 @@ where })) } - fn fail(entropy: FailEntropy) -> Self { + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { Arc::new(Node { cmr: Cmr::fail(entropy), - data: N::CachedData::fail(entropy), + data: N::CachedData::fail(inference_context, entropy), inner: Inner::Fail(entropy), }) } - fn const_word(value: Arc) -> Self { + fn const_word(inference_context: &types::Context, value: Arc) -> Self { Arc::new(Node { cmr: Cmr::const_word(&value), - data: N::CachedData::const_word(Arc::clone(&value)), + data: N::CachedData::const_word(inference_context, Arc::clone(&value)), inner: Inner::Word(value), }) } + + fn inference_context(&self) -> &types::Context { + self.data.inference_context() + } } impl DisconnectConstructible for Arc> @@ -497,10 +511,10 @@ where N: Marker, N::CachedData: WitnessConstructible, { - fn witness(value: N::Witness) -> Self { + fn witness(inference_context: &types::Context, value: N::Witness) -> Self { Arc::new(Node { cmr: Cmr::witness(), - data: N::CachedData::witness(value.clone()), + data: N::CachedData::witness(inference_context, value.clone()), inner: Inner::Witness(value), }) } @@ -511,10 +525,10 @@ where N: Marker, N::CachedData: JetConstructible, { - fn jet(jet: N::Jet) -> Self { + fn jet(inference_context: &types::Context, jet: N::Jet) -> Self { Arc::new(Node { cmr: Cmr::jet(jet), - data: N::CachedData::jet(jet), + data: N::CachedData::jet(inference_context, jet), inner: Inner::Jet(jet), }) } diff --git a/src/node/redeem.rs b/src/node/redeem.rs index 0679ec22..11e5eb49 100644 --- a/src/node/redeem.rs +++ b/src/node/redeem.rs @@ -223,7 +223,10 @@ impl RedeemNode { /// Convert a [`RedeemNode`] back into a [`WitnessNode`] /// by loosening the finalized types, witness data and disconnected branches. pub fn to_witness_node(&self) -> Arc> { - struct ToWitness(PhantomData); + struct ToWitness { + inference_context: types::Context, + phantom: PhantomData, + } impl Converter, Witness> for ToWitness { type Error = (); @@ -258,12 +261,16 @@ impl RedeemNode { let inner = inner .map(|node| node.cached_data()) .map_witness(|maybe_value| maybe_value.clone()); - Ok(WitnessData::from_inner(inner).expect("types are already finalized")) + Ok(WitnessData::from_inner(&self.inference_context, inner) + .expect("types were already finalized")) } } - self.convert::(&mut ToWitness(PhantomData)) - .unwrap() + self.convert::(&mut ToWitness { + inference_context: types::Context::new(), + phantom: PhantomData, + }) + .unwrap() } /// Decode a Simplicity program from bits, including the witness data. @@ -283,7 +290,8 @@ impl RedeemNode { data: &PostOrderIterItem<&ConstructNode>, _: &NoWitness, ) -> Result, Self::Error> { - let target_ty = data.node.data.arrow().target.finalize()?; + let arrow = data.node.data.arrow(); + let target_ty = arrow.target.finalize()?; self.bits.read_value(&target_ty).map_err(Error::from) } diff --git a/src/node/witness.rs b/src/node/witness.rs index 41cd9086..b9748c58 100644 --- a/src/node/witness.rs +++ b/src/node/witness.rs @@ -74,7 +74,10 @@ impl WitnessNode { } pub fn prune_and_retype(&self) -> Arc { - struct Retyper(PhantomData); + struct Retyper { + inference_context: types::Context, + phantom: PhantomData, + } impl Converter, Witness> for Retyper { type Error = types::Error; @@ -131,7 +134,8 @@ impl WitnessNode { .map(|node| node.cached_data()) .map_witness(Option::>::clone); // This next line does the actual retyping. - let mut retyped = WitnessData::from_inner(converted_inner)?; + let mut retyped = + WitnessData::from_inner(&self.inference_context, converted_inner)?; // Sometimes we set the prune bit on nodes without setting that // of their children; in this case the prune bit inferred from // `converted_inner` will be incorrect. @@ -144,8 +148,11 @@ impl WitnessNode { // FIXME after running the `ReTyper` we should run a `WitnessShrinker` which // shrinks the witness data in case the ReTyper shrank its types. - self.convert::(&mut Retyper(PhantomData)) - .expect("type inference won't fail if it succeeded before") + self.convert::(&mut Retyper { + inference_context: types::Context::new(), + phantom: PhantomData, + }) + .expect("type inference won't fail if it succeeded before") } pub fn finalize(&self) -> Result>, Error> { @@ -198,15 +205,18 @@ impl WitnessNode { // 1. First, prune everything that we can let pruned_self = self.prune_and_retype(); // 2. Then, set the root arrow to 1->1 - let unit_ty = types::Type::unit(); - pruned_self - .arrow() - .source - .unify(&unit_ty, "setting root source to unit")?; - pruned_self - .arrow() - .target - .unify(&unit_ty, "setting root source to unit")?; + let ctx = pruned_self.inference_context(); + let unit_ty = types::Type::unit(ctx); + ctx.unify( + &pruned_self.arrow().source, + &unit_ty, + "setting root source to unit", + )?; + ctx.unify( + &pruned_self.arrow().target, + &unit_ty, + "setting root target to unit", + )?; // 3. Then attempt to convert the whole program to a RedeemNode. // Despite all of the above this can still fail due to the @@ -228,17 +238,17 @@ pub struct WitnessData { } impl CoreConstructible for WitnessData { - fn iden() -> Self { + fn iden(inference_context: &types::Context) -> Self { WitnessData { - arrow: Arrow::iden(), + arrow: Arrow::iden(inference_context), must_prune: false, phantom: PhantomData, } } - fn unit() -> Self { + fn unit(inference_context: &types::Context) -> Self { WitnessData { - arrow: Arrow::unit(), + arrow: Arrow::unit(inference_context), must_prune: false, phantom: PhantomData, } @@ -319,22 +329,26 @@ impl CoreConstructible for WitnessData { }) } - fn fail(entropy: FailEntropy) -> Self { + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { // Fail nodes always get pruned. WitnessData { - arrow: Arrow::fail(entropy), + arrow: Arrow::fail(inference_context, entropy), must_prune: true, phantom: PhantomData, } } - fn const_word(word: Arc) -> Self { + fn const_word(inference_context: &types::Context, word: Arc) -> Self { WitnessData { - arrow: Arrow::const_word(word), + arrow: Arrow::const_word(inference_context, word), must_prune: false, phantom: PhantomData, } } + + fn inference_context(&self) -> &types::Context { + self.arrow.inference_context() + } } impl DisconnectConstructible>>> for WitnessData { @@ -349,9 +363,9 @@ impl DisconnectConstructible>>> for WitnessDat } impl WitnessConstructible>> for WitnessData { - fn witness(witness: Option>) -> Self { + fn witness(inference_context: &types::Context, witness: Option>) -> Self { WitnessData { - arrow: Arrow::witness(NoWitness), + arrow: Arrow::witness(inference_context, NoWitness), must_prune: witness.is_none(), phantom: PhantomData, } @@ -359,9 +373,9 @@ impl WitnessConstructible>> for WitnessData { } impl JetConstructible for WitnessData { - fn jet(jet: J) -> Self { + fn jet(inference_context: &types::Context, jet: J) -> Self { WitnessData { - arrow: Arrow::jet(jet), + arrow: Arrow::jet(inference_context, jet), must_prune: false, phantom: PhantomData, } diff --git a/src/policy/ast.rs b/src/policy/ast.rs index 5d55ec29..94360646 100644 --- a/src/policy/ast.rs +++ b/src/policy/ast.rs @@ -17,6 +17,7 @@ use crate::node::{ ConstructNode, CoreConstructible, JetConstructible, NoWitness, WitnessConstructible, }; use crate::policy::serialize::{self, AssemblyConstructible}; +use crate::types; use crate::{Cmr, CommitNode, FailEntropy}; use crate::{SimplicityKey, ToXOnlyPubkey, Translator}; @@ -58,7 +59,7 @@ pub enum Policy { impl Policy { /// Serializes the policy as a Simplicity fragment, with all witness nodes unpopulated. - fn serialize_no_witness(&self) -> Option + fn serialize_no_witness(&self, inference_context: &types::Context) -> Option where N: CoreConstructible + JetConstructible @@ -66,54 +67,62 @@ impl Policy { + AssemblyConstructible, { match *self { - Policy::Unsatisfiable(entropy) => Some(serialize::unsatisfiable(entropy)), - Policy::Trivial => Some(serialize::trivial()), - Policy::After(n) => Some(serialize::after(n)), - Policy::Older(n) => Some(serialize::older(n)), - Policy::Key(ref key) => Some(serialize::key(key, NoWitness)), - Policy::Sha256(ref hash) => Some(serialize::sha256::(hash, NoWitness)), + Policy::Unsatisfiable(entropy) => { + Some(serialize::unsatisfiable(inference_context, entropy)) + } + Policy::Trivial => Some(serialize::trivial(inference_context)), + Policy::After(n) => Some(serialize::after(inference_context, n)), + Policy::Older(n) => Some(serialize::older(inference_context, n)), + Policy::Key(ref key) => Some(serialize::key(inference_context, key, NoWitness)), + Policy::Sha256(ref hash) => Some(serialize::sha256::( + inference_context, + hash, + NoWitness, + )), Policy::And { ref left, ref right, } => { - let left = left.serialize_no_witness()?; - let right = right.serialize_no_witness()?; + let left = left.serialize_no_witness(inference_context)?; + let right = right.serialize_no_witness(inference_context)?; Some(serialize::and(&left, &right)) } Policy::Or { ref left, ref right, } => { - let left = left.serialize_no_witness()?; - let right = right.serialize_no_witness()?; + let left = left.serialize_no_witness(inference_context)?; + let right = right.serialize_no_witness(inference_context)?; Some(serialize::or(&left, &right, NoWitness)) } Policy::Threshold(k, ref subs) => { let k = u32::try_from(k).expect("can have k at most 2^32 in a threshold"); let subs = subs .iter() - .map(Self::serialize_no_witness) + .map(|sub| sub.serialize_no_witness(inference_context)) .collect::>>()?; let wits = iter::repeat(NoWitness) .take(subs.len()) .collect::>(); Some(serialize::threshold(k, &subs, &wits)) } - Policy::Assembly(cmr) => N::assembly(cmr), + Policy::Assembly(cmr) => N::assembly(inference_context, cmr), } } /// Return the program commitment of the policy. pub fn commit(&self) -> Option>> { - let construct: Arc> = self.serialize_no_witness()?; + let construct: Arc> = + self.serialize_no_witness(&types::Context::new())?; let commit = construct.finalize_types().expect("policy has sound types"); Some(commit) } /// Return the CMR of the policy. pub fn cmr(&self) -> Cmr { - self.serialize_no_witness() + self.serialize_no_witness::(&types::Context::new()) .expect("CMR is defined for asm fragment") + .cmr } } diff --git a/src/policy/satisfy.rs b/src/policy/satisfy.rs index f8888a29..94d43be2 100644 --- a/src/policy/satisfy.rs +++ b/src/policy/satisfy.rs @@ -4,6 +4,7 @@ use crate::analysis::Cost; use crate::jet::Elements; use crate::node::{RedeemNode, WitnessNode}; use crate::policy::ToXOnlyPubkey; +use crate::types; use crate::{Cmr, Error, Policy, Value}; use elements::bitcoin; @@ -93,19 +94,22 @@ impl Satisfier for elements::LockTime { impl Policy { fn satisfy_internal>( &self, + inference_context: &types::Context, satisfier: &S, ) -> Result>, Error> { let node = match *self { - Policy::Unsatisfiable(entropy) => super::serialize::unsatisfiable(entropy), - Policy::Trivial => super::serialize::trivial(), + Policy::Unsatisfiable(entropy) => { + super::serialize::unsatisfiable(inference_context, entropy) + } + Policy::Trivial => super::serialize::trivial(inference_context), Policy::Key(ref key) => { let sig_wit = satisfier .lookup_tap_leaf_script_sig(key, &TapLeafHash::all_zeros()) .map(|sig| Value::u512_from_slice(sig.sig.as_ref())); - super::serialize::key(key, sig_wit) + super::serialize::key(inference_context, key, sig_wit) } Policy::After(n) => { - let node = super::serialize::after::>(n); + let node = super::serialize::after::>(inference_context, n); let height = Height::from_consensus(n).expect("timelock is valid"); if satisfier.check_after(elements::LockTime::Blocks(height)) { node @@ -114,7 +118,7 @@ impl Policy { } } Policy::Older(n) => { - let node = super::serialize::older::>(n); + let node = super::serialize::older::>(inference_context, n); if satisfier.check_older(elements::Sequence((n).into())) { node } else { @@ -125,22 +129,22 @@ impl Policy { let preimage_wit = satisfier .lookup_sha256(hash) .map(|preimage| Value::u256_from_slice(preimage.as_ref())); - super::serialize::sha256::(hash, preimage_wit) + super::serialize::sha256::(inference_context, hash, preimage_wit) } Policy::And { ref left, ref right, } => { - let left = left.satisfy_internal(satisfier)?; - let right = right.satisfy_internal(satisfier)?; + let left = left.satisfy_internal(inference_context, satisfier)?; + let right = right.satisfy_internal(inference_context, satisfier)?; super::serialize::and(&left, &right) } Policy::Or { ref left, ref right, } => { - let left = left.satisfy_internal(satisfier)?; - let right = right.satisfy_internal(satisfier)?; + let left = left.satisfy_internal(inference_context, satisfier)?; + let right = right.satisfy_internal(inference_context, satisfier)?; let take_right = match (left.must_prune(), right.must_prune()) { (false, false) => { @@ -165,7 +169,7 @@ impl Policy { Policy::Threshold(k, ref subs) => { let nodes: Result>>, Error> = subs .iter() - .map(|sub| sub.satisfy_internal(satisfier)) + .map(|sub| sub.satisfy_internal(inference_context, satisfier)) .collect(); let mut nodes = nodes?; let mut costs = vec![Cost::CONSENSUS_MAX; subs.len()]; @@ -215,7 +219,7 @@ impl Policy { &self, satisfier: &S, ) -> Result>, Error> { - let witnode = self.satisfy_internal(satisfier)?; + let witnode = self.satisfy_internal(&types::Context::new(), satisfier)?; if witnode.must_prune() { Err(Error::IncompleteFinalization) } else { @@ -626,17 +630,18 @@ mod tests { #[test] fn satisfy_asm() { + let ctx = types::Context::new(); let env = ElementsEnv::dummy(); let mut satisfier = get_satisfier(&env); let mut assert_branch = |witness0: Arc, witness1: Arc| { let asm_program = serialize::verify_bexp( &Arc::>::pair( - &Arc::>::witness(Some(witness0.clone())), - &Arc::>::witness(Some(witness1.clone())), + &Arc::>::witness(&ctx, Some(witness0.clone())), + &Arc::>::witness(&ctx, Some(witness1.clone())), ) .expect("sound types"), - &Arc::>::jet(Elements::Eq8), + &Arc::>::jet(&ctx, Elements::Eq8), ); let cmr = asm_program.cmr(); satisfier.assembly.insert(cmr, asm_program); diff --git a/src/policy/serialize.rs b/src/policy/serialize.rs index 2322fbf0..2c586657 100644 --- a/src/policy/serialize.rs +++ b/src/policy/serialize.rs @@ -3,7 +3,9 @@ //! Serialization of Policy as Simplicity use crate::jet::{Elements, Jet}; +use crate::merkle::cmr::ConstructibleCmr; use crate::node::{CoreConstructible, JetConstructible, WitnessConstructible}; +use crate::types; use crate::{Cmr, ConstructNode, ToXOnlyPubkey}; use crate::{FailEntropy, Value}; @@ -14,70 +16,73 @@ use std::sync::Arc; pub trait AssemblyConstructible: Sized { /// Construct the assembly fragment with the given CMR. /// - /// The construction fails if the CMR alone is not enough information to construct the type. - fn assembly(cmr: Cmr) -> Option; + /// The construction fails if the CMR alone is not enough information to construct the object. + fn assembly(inference_context: &types::Context, cmr: Cmr) -> Option; } -impl AssemblyConstructible for Cmr { - fn assembly(cmr: Cmr) -> Option { - Some(cmr) +impl AssemblyConstructible for ConstructibleCmr { + fn assembly(inference_context: &types::Context, cmr: Cmr) -> Option { + Some(ConstructibleCmr { + cmr, + inference_context: inference_context.shallow_clone(), + }) } } impl AssemblyConstructible for Arc> { - fn assembly(_cmr: Cmr) -> Option { + fn assembly(_: &types::Context, _cmr: Cmr) -> Option { None } } -pub fn unsatisfiable(entropy: FailEntropy) -> N +pub fn unsatisfiable(inference_context: &types::Context, entropy: FailEntropy) -> N where N: CoreConstructible, { - N::fail(entropy) + N::fail(inference_context, entropy) } -pub fn trivial() -> N +pub fn trivial(inference_context: &types::Context) -> N where N: CoreConstructible, { - N::unit() + N::unit(inference_context) } -pub fn key(key: &Pk, witness: W) -> N +pub fn key(inference_context: &types::Context, key: &Pk, witness: W) -> N where Pk: ToXOnlyPubkey, N: CoreConstructible + JetConstructible + WitnessConstructible, { let key_value = Value::u256_from_slice(&key.to_x_only_pubkey().serialize()); - let const_key = N::const_word(key_value); - let sighash_all = N::jet(Elements::SigAllHash); + let const_key = N::const_word(inference_context, key_value); + let sighash_all = N::jet(inference_context, Elements::SigAllHash); let pair_key_msg = N::pair(&const_key, &sighash_all).expect("consistent types"); - let witness = N::witness(witness); + let witness = N::witness(inference_context, witness); let pair_key_msg_sig = N::pair(&pair_key_msg, &witness).expect("consistent types"); - let bip_0340_verify = N::jet(Elements::Bip0340Verify); + let bip_0340_verify = N::jet(inference_context, Elements::Bip0340Verify); N::comp(&pair_key_msg_sig, &bip_0340_verify).expect("consistent types") } -pub fn after(n: u32) -> N +pub fn after(inference_context: &types::Context, n: u32) -> N where N: CoreConstructible + JetConstructible, { let n_value = Value::u32(n); - let const_n = N::const_word(n_value); - let check_lock_height = N::jet(Elements::CheckLockHeight); + let const_n = N::const_word(inference_context, n_value); + let check_lock_height = N::jet(inference_context, Elements::CheckLockHeight); N::comp(&const_n, &check_lock_height).expect("consistent types") } -pub fn older(n: u16) -> N +pub fn older(inference_context: &types::Context, n: u16) -> N where N: CoreConstructible + JetConstructible, { let n_value = Value::u16(n); - let const_n = N::const_word(n_value); - let check_lock_distance = N::jet(Elements::CheckLockDistance); + let const_n = N::const_word(inference_context, n_value); + let check_lock_distance = N::jet(inference_context, Elements::CheckLockDistance); N::comp(&const_n, &check_lock_distance).expect("consistent types") } @@ -86,11 +91,11 @@ pub fn compute_sha256(witness256: &N) -> N where N: CoreConstructible + JetConstructible, { - let ctx = N::jet(Elements::Sha256Ctx8Init); + let ctx = N::jet(witness256.inference_context(), Elements::Sha256Ctx8Init); let pair_ctx_witness = N::pair(&ctx, witness256).expect("consistent types"); - let add256 = N::jet(Elements::Sha256Ctx8Add32); + let add256 = N::jet(witness256.inference_context(), Elements::Sha256Ctx8Add32); let digest_ctx = N::comp(&pair_ctx_witness, &add256).expect("consistent types"); - let finalize = N::jet(Elements::Sha256Ctx8Finalize); + let finalize = N::jet(witness256.inference_context(), Elements::Sha256Ctx8Finalize); N::comp(&digest_ctx, &finalize).expect("consistent types") } @@ -98,22 +103,27 @@ pub fn verify_bexp(input: &N, bexp: &N) -> N where N: CoreConstructible + JetConstructible, { + assert_eq!( + input.inference_context(), + bexp.inference_context(), + "cannot compose policy fragments with different type inference contexts", + ); let computed_bexp = N::comp(input, bexp).expect("consistent types"); - let verify = N::jet(Elements::Verify); + let verify = N::jet(input.inference_context(), Elements::Verify); N::comp(&computed_bexp, &verify).expect("consistent types") } -pub fn sha256(hash: &Pk::Sha256, witness: W) -> N +pub fn sha256(inference_context: &types::Context, hash: &Pk::Sha256, witness: W) -> N where Pk: ToXOnlyPubkey, N: CoreConstructible + JetConstructible + WitnessConstructible, { let hash_value = Value::u256_from_slice(Pk::to_sha256(hash).as_ref()); - let const_hash = N::const_word(hash_value); - let witness256 = N::witness(witness); + let const_hash = N::const_word(inference_context, hash_value); + let witness256 = N::witness(inference_context, witness); let computed_hash = compute_sha256(&witness256); let pair_hash_computed_hash = N::pair(&const_hash, &computed_hash).expect("consistent types"); - let eq256 = N::jet(Elements::Eq256); + let eq256 = N::jet(inference_context, Elements::Eq256); verify_bexp(&pair_hash_computed_hash, &eq256) } @@ -125,12 +135,12 @@ where N::comp(left, right).expect("consistent types") } -pub fn selector(witness_bit: W) -> N +pub fn selector(inference_context: &types::Context, witness_bit: W) -> N where N: CoreConstructible + WitnessConstructible, { - let witness = N::witness(witness_bit); - let unit = N::unit(); + let witness = N::witness(inference_context, witness_bit); + let unit = N::unit(inference_context); N::pair(&witness, &unit).expect("consistent types") } @@ -138,10 +148,15 @@ pub fn or(left: &N, right: &N, witness_bit: W) -> N where N: CoreConstructible + WitnessConstructible, { + assert_eq!( + left.inference_context(), + right.inference_context(), + "cannot compose policy fragments with different type inference contexts", + ); let drop_left = N::drop_(left); let drop_right = N::drop_(right); let case_left_right = N::case(&drop_left, &drop_right).expect("consistent types"); - let selector = selector(witness_bit); + let selector = selector(left.inference_context(), witness_bit); N::comp(&selector, &case_left_right).expect("consistent types") } @@ -154,14 +169,14 @@ where N: CoreConstructible + WitnessConstructible, { // 1 → 2 x 1 - let selector = selector(witness_bit); + let selector = selector(child.inference_context(), witness_bit); // 1 → 2^32 - let const_one = N::const_word(Value::u32(1)); + let const_one = N::const_word(child.inference_context(), Value::u32(1)); // 1 → 2^32 let child_one = N::comp(child, &const_one).expect("consistent types"); // 1 → 2^32 - let const_zero = N::const_word(Value::u32(0)); + let const_zero = N::const_word(child.inference_context(), Value::u32(0)); // 1 × 1 → 2^32 let drop_left = N::drop_(&const_zero); @@ -181,14 +196,19 @@ pub fn thresh_add(sum: &N, summand: &N) -> N where N: CoreConstructible + JetConstructible, { + assert_eq!( + sum.inference_context(), + summand.inference_context(), + "cannot compose policy fragments with different type inference contexts", + ); // 1 → 2^32 × 2^32 let pair_sum_summand = N::pair(sum, summand).expect("consistent types"); // 2^32 × 2^32 → 2 × 2^32 - let add32 = N::jet(Elements::Add32); + let add32 = N::jet(sum.inference_context(), Elements::Add32); // 1 → 2 x 2^32 let full_sum = N::comp(&pair_sum_summand, &add32).expect("consistent types"); // 2^32 → 2^32 - let iden = N::iden(); + let iden = N::iden(sum.inference_context()); // 2 × 2^32 → 2^32 let drop_iden = N::drop_(&iden); @@ -204,11 +224,11 @@ where N: CoreConstructible + JetConstructible, { // 1 → 2^32 - let const_k = N::const_word(Value::u32(k)); + let const_k = N::const_word(sum.inference_context(), Value::u32(k)); // 1 → 2^32 × 2^32 let pair_k_sum = N::pair(&const_k, sum).expect("consistent types"); // 2^32 × 2^32 → 2 - let eq32 = N::jet(Elements::Eq32); + let eq32 = N::jet(sum.inference_context(), Elements::Eq32); // 1 → 1 verify_bexp(&pair_k_sum, &eq32) diff --git a/src/types/arrow.rs b/src/types/arrow.rs index c18dc9e0..8fbee5f6 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -18,19 +18,30 @@ use crate::node::{ CoreConstructible, DisconnectConstructible, JetConstructible, NoDisconnect, WitnessConstructible, }; -use crate::types::{Bound, Error, Final, Type}; +use crate::types::{Context, Error, Final, Type}; use crate::{jet::Jet, Value}; use super::variable::new_name; /// A container for an expression's source and target types, whether or not /// these types are complete. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Arrow { /// The source type pub source: Type, /// The target type pub target: Type, + /// Type inference context for both types. + pub inference_context: Context, +} + +// Having `Clone` makes it easier to derive Clone on structures +// that contain Arrow, even though it is potentially confusing +// to use `.clone` to mean a shallow clone. +impl Clone for Arrow { + fn clone(&self) -> Self { + self.shallow_clone() + } } impl fmt::Display for Arrow { @@ -74,168 +85,61 @@ impl Arrow { }) } - /// Create a unification arrow for a fresh `unit` combinator - pub fn for_unit() -> Self { - Arrow { - source: Type::free(new_name("unit_src_")), - target: Type::unit(), - } - } - - /// Create a unification arrow for a fresh `iden` combinator - pub fn for_iden() -> Self { - // Throughout this module, when two types are the same, we reuse a - // pointer to them rather than creating distinct types and unifying - // them. This theoretically could lead to more confusing errors for - // the user during type inference, but in practice type inference - // is completely opaque and there's no harm in making it moreso. - let new = Type::free(new_name("iden_src_")); - Arrow { - source: new.shallow_clone(), - target: new, - } - } - - /// Create a unification arrow for a fresh `witness` combinator - pub fn for_witness() -> Self { - Arrow { - source: Type::free(new_name("witness_src_")), - target: Type::free(new_name("witness_tgt_")), - } - } - - /// Create a unification arrow for a fresh `fail` combinator - pub fn for_fail() -> Self { - Arrow { - source: Type::free(new_name("fail_src_")), - target: Type::free(new_name("fail_tgt_")), - } - } - - /// Create a unification arrow for a fresh jet combinator - pub fn for_jet(jet: J) -> Self { - Arrow { - source: jet.source_ty().to_type(), - target: jet.target_ty().to_type(), - } - } - - /// Create a unification arrow for a fresh const-word combinator - pub fn for_const_word(word: &Value) -> Self { - let len = word.len(); - assert!(len > 0, "Words must not be the empty bitstring"); - assert!(len.is_power_of_two()); - let depth = word.len().trailing_zeros(); - Arrow { - source: Type::unit(), - target: Type::two_two_n(depth as usize), - } - } - - /// Create a unification arrow for a fresh `injl` combinator - pub fn for_injl(child_arrow: &Arrow) -> Self { - Arrow { - source: child_arrow.source.shallow_clone(), - target: Type::sum( - child_arrow.target.shallow_clone(), - Type::free(new_name("injl_tgt_")), - ), - } - } - - /// Create a unification arrow for a fresh `injr` combinator - pub fn for_injr(child_arrow: &Arrow) -> Self { - Arrow { - source: child_arrow.source.shallow_clone(), - target: Type::sum( - Type::free(new_name("injr_tgt_")), - child_arrow.target.shallow_clone(), - ), - } - } - - /// Create a unification arrow for a fresh `take` combinator - pub fn for_take(child_arrow: &Arrow) -> Self { - Arrow { - source: Type::product( - child_arrow.source.shallow_clone(), - Type::free(new_name("take_src_")), - ), - target: child_arrow.target.shallow_clone(), - } - } - - /// Create a unification arrow for a fresh `drop` combinator - pub fn for_drop(child_arrow: &Arrow) -> Self { + /// Same as [`Self::clone`] but named to make it clearer that this is cheap + pub fn shallow_clone(&self) -> Self { Arrow { - source: Type::product( - Type::free(new_name("drop_src_")), - child_arrow.source.shallow_clone(), - ), - target: child_arrow.target.shallow_clone(), + source: self.source.shallow_clone(), + target: self.target.shallow_clone(), + inference_context: self.inference_context.shallow_clone(), } } - /// Create a unification arrow for a fresh `pair` combinator - pub fn for_pair(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { - lchild_arrow.source.unify( - &rchild_arrow.source, - "pair combinator: left source = right source", - )?; - Ok(Arrow { - source: lchild_arrow.source.shallow_clone(), - target: Type::product( - lchild_arrow.target.shallow_clone(), - rchild_arrow.target.shallow_clone(), - ), - }) - } - - /// Create a unification arrow for a fresh `comp` combinator - pub fn for_comp(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { - lchild_arrow.target.unify( - &rchild_arrow.source, - "comp combinator: left target = right source", - )?; - Ok(Arrow { - source: lchild_arrow.source.shallow_clone(), - target: rchild_arrow.target.shallow_clone(), - }) - } - /// Create a unification arrow for a fresh `case` combinator /// /// Either child may be `None`, in which case the combinator is assumed to be /// an assertion, which for type-inference purposes means there are no bounds /// on the missing child. /// - /// If neither child is provided, this function will not raise an error; it - /// is the responsibility of the caller to detect this case and error elsewhere. - pub fn for_case( - lchild_arrow: Option<&Arrow>, - rchild_arrow: Option<&Arrow>, - ) -> Result { - let a = Type::free(new_name("case_a_")); - let b = Type::free(new_name("case_b_")); - let c = Type::free(new_name("case_c_")); - - let sum_a_b = Type::sum(a.shallow_clone(), b.shallow_clone()); - let prod_sum_a_b_c = Type::product(sum_a_b, c.shallow_clone()); - - let target = Type::free(String::new()); + /// # Panics + /// + /// If neither child is provided, this function will panic. + fn for_case(lchild_arrow: Option<&Arrow>, rchild_arrow: Option<&Arrow>) -> Result { + if let (Some(left), Some(right)) = (lchild_arrow, rchild_arrow) { + left.inference_context.check_eq(&right.inference_context)?; + } + + let ctx = match (lchild_arrow, rchild_arrow) { + (Some(left), _) => left.inference_context.shallow_clone(), + (_, Some(right)) => right.inference_context.shallow_clone(), + (None, None) => panic!("called `for_case` with no children"), + }; + + let a = Type::free(&ctx, new_name("case_a_")); + let b = Type::free(&ctx, new_name("case_b_")); + let c = Type::free(&ctx, new_name("case_c_")); + + let sum_a_b = Type::sum(&ctx, a.shallow_clone(), b.shallow_clone()); + let prod_sum_a_b_c = Type::product(&ctx, sum_a_b, c.shallow_clone()); + + let target = Type::free(&ctx, String::new()); if let Some(lchild_arrow) = lchild_arrow { - lchild_arrow.source.bind( - Arc::new(Bound::Product(a, c.shallow_clone())), + ctx.bind_product( + &lchild_arrow.source, + &a, + &c, "case combinator: left source = A × C", )?; - target.unify(&lchild_arrow.target, "").unwrap(); + ctx.unify(&target, &lchild_arrow.target, "").unwrap(); } if let Some(rchild_arrow) = rchild_arrow { - rchild_arrow.source.bind( - Arc::new(Bound::Product(b, c)), + ctx.bind_product( + &rchild_arrow.source, + &b, + &c, "case combinator: left source = B × C", )?; - target.unify( + ctx.unify( + &target, &rchild_arrow.target, "case combinator: left target = right target", )?; @@ -244,71 +148,128 @@ impl Arrow { Ok(Arrow { source: prod_sum_a_b_c, target, + inference_context: ctx, }) } - /// Create a unification arrow for a fresh `comp` combinator - pub fn for_disconnect(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { - let a = Type::free(new_name("disconnect_a_")); - let b = Type::free(new_name("disconnect_b_")); + /// Helper function to combine code for the two `DisconnectConstructible` impls for [`Arrow`]. + fn for_disconnect(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { + lchild_arrow + .inference_context + .check_eq(&rchild_arrow.inference_context)?; + + let ctx = lchild_arrow.inference_context(); + let a = Type::free(ctx, new_name("disconnect_a_")); + let b = Type::free(ctx, new_name("disconnect_b_")); let c = rchild_arrow.source.shallow_clone(); let d = rchild_arrow.target.shallow_clone(); - let prod_256_a = Bound::Product(Type::two_two_n(8), a.shallow_clone()); - let prod_b_c = Bound::Product(b.shallow_clone(), c); - let prod_b_d = Type::product(b, d); - - lchild_arrow.source.bind( - Arc::new(prod_256_a), + ctx.bind_product( + &lchild_arrow.source, + &Type::two_two_n(ctx, 8), + &a, "disconnect combinator: left source = 2^256 × A", )?; - lchild_arrow.target.bind( - Arc::new(prod_b_c), + ctx.bind_product( + &lchild_arrow.target, + &b, + &c, "disconnect combinator: left target = B × C", )?; + let prod_b_d = Type::product(ctx, b, d); + Ok(Arrow { source: a, target: prod_b_d, + inference_context: lchild_arrow.inference_context.shallow_clone(), }) } - - /// Same as [`Self::clone`] but named to make it clearer that this is cheap - pub fn shallow_clone(&self) -> Self { - Arrow { - source: self.source.shallow_clone(), - target: self.target.shallow_clone(), - } - } } impl CoreConstructible for Arrow { - fn iden() -> Self { - Self::for_iden() + fn iden(inference_context: &Context) -> Self { + // Throughout this module, when two types are the same, we reuse a + // pointer to them rather than creating distinct types and unifying + // them. This theoretically could lead to more confusing errors for + // the user during type inference, but in practice type inference + // is completely opaque and there's no harm in making it moreso. + let new = Type::free(inference_context, new_name("iden_src_")); + Arrow { + source: new.shallow_clone(), + target: new, + inference_context: inference_context.shallow_clone(), + } } - fn unit() -> Self { - Self::for_unit() + fn unit(inference_context: &Context) -> Self { + Arrow { + source: Type::free(inference_context, new_name("unit_src_")), + target: Type::unit(inference_context), + inference_context: inference_context.shallow_clone(), + } } fn injl(child: &Self) -> Self { - Self::for_injl(child) + Arrow { + source: child.source.shallow_clone(), + target: Type::sum( + &child.inference_context, + child.target.shallow_clone(), + Type::free(&child.inference_context, new_name("injl_tgt_")), + ), + inference_context: child.inference_context.shallow_clone(), + } } fn injr(child: &Self) -> Self { - Self::for_injr(child) + Arrow { + source: child.source.shallow_clone(), + target: Type::sum( + &child.inference_context, + Type::free(&child.inference_context, new_name("injr_tgt_")), + child.target.shallow_clone(), + ), + inference_context: child.inference_context.shallow_clone(), + } } fn take(child: &Self) -> Self { - Self::for_take(child) + Arrow { + source: Type::product( + &child.inference_context, + child.source.shallow_clone(), + Type::free(&child.inference_context, new_name("take_src_")), + ), + target: child.target.shallow_clone(), + inference_context: child.inference_context.shallow_clone(), + } } fn drop_(child: &Self) -> Self { - Self::for_drop(child) + Arrow { + source: Type::product( + &child.inference_context, + Type::free(&child.inference_context, new_name("drop_src_")), + child.source.shallow_clone(), + ), + target: child.target.shallow_clone(), + inference_context: child.inference_context.shallow_clone(), + } } fn comp(left: &Self, right: &Self) -> Result { - Self::for_comp(left, right) + left.inference_context.check_eq(&right.inference_context)?; + left.inference_context.unify( + &left.target, + &right.source, + "comp combinator: left target = right source", + )?; + Ok(Arrow { + source: left.source.shallow_clone(), + target: right.target.shallow_clone(), + inference_context: left.inference_context.shallow_clone(), + }) } fn case(left: &Self, right: &Self) -> Result { @@ -324,15 +285,45 @@ impl CoreConstructible for Arrow { } fn pair(left: &Self, right: &Self) -> Result { - Self::for_pair(left, right) + left.inference_context.check_eq(&right.inference_context)?; + left.inference_context.unify( + &left.source, + &right.source, + "pair combinator: left source = right source", + )?; + Ok(Arrow { + source: left.source.shallow_clone(), + target: Type::product( + &left.inference_context, + left.target.shallow_clone(), + right.target.shallow_clone(), + ), + inference_context: left.inference_context.shallow_clone(), + }) } - fn fail(_: crate::FailEntropy) -> Self { - Self::for_fail() + fn fail(inference_context: &Context, _: crate::FailEntropy) -> Self { + Arrow { + source: Type::free(inference_context, new_name("fail_src_")), + target: Type::free(inference_context, new_name("fail_tgt_")), + inference_context: inference_context.shallow_clone(), + } + } + + fn const_word(inference_context: &Context, word: Arc) -> Self { + let len = word.len(); + assert!(len > 0, "Words must not be the empty bitstring"); + assert!(len.is_power_of_two()); + let depth = word.len().trailing_zeros(); + Arrow { + source: Type::unit(inference_context), + target: Type::two_two_n(inference_context, depth as usize), + inference_context: inference_context.shallow_clone(), + } } - fn const_word(word: Arc) -> Self { - Self::for_const_word(&word) + fn inference_context(&self) -> &Context { + &self.inference_context } } @@ -344,11 +335,14 @@ impl DisconnectConstructible for Arrow { impl DisconnectConstructible for Arrow { fn disconnect(left: &Self, _: &NoDisconnect) -> Result { + let source = Type::free(&left.inference_context, "disc_src".into()); + let target = Type::free(&left.inference_context, "disc_tgt".into()); Self::for_disconnect( left, &Arrow { - source: Type::free("disc_src".into()), - target: Type::free("disc_tgt".into()), + source, + target, + inference_context: left.inference_context.shallow_clone(), }, ) } @@ -364,13 +358,21 @@ impl DisconnectConstructible> for Arrow { } impl JetConstructible for Arrow { - fn jet(jet: J) -> Self { - Self::for_jet(jet) + fn jet(inference_context: &Context, jet: J) -> Self { + Arrow { + source: jet.source_ty().to_type(inference_context), + target: jet.target_ty().to_type(inference_context), + inference_context: inference_context.shallow_clone(), + } } } impl WitnessConstructible for Arrow { - fn witness(_: W) -> Self { - Self::for_witness() + fn witness(inference_context: &Context, _: W) -> Self { + Arrow { + source: Type::free(inference_context, new_name("witness_src_")), + target: Type::free(inference_context, new_name("witness_tgt_")), + inference_context: inference_context.shallow_clone(), + } } } diff --git a/src/types/context.rs b/src/types/context.rs new file mode 100644 index 00000000..efa2b898 --- /dev/null +++ b/src/types/context.rs @@ -0,0 +1,434 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Type Inference Context +//! +//! When constructing a Simplicity program, you must first create a type inference +//! context, in which type inference occurs incrementally during construction. Each +//! leaf node (e.g. `unit` and `iden`) must explicitly refer to the type inference +//! context, while combinator nodes (e.g. `comp`) infer the context from their +//! children, raising an error if there are multiple children whose contexts don't +//! match. +//! +//! This helps to prevent situations in which users attempt to construct multiple +//! independent programs, but types in one program accidentally refer to types in +//! the other. +//! + +use std::fmt; +use std::sync::{Arc, Mutex, MutexGuard}; + +use crate::dag::{Dag, DagLike}; + +use super::{Bound, CompleteBound, Error, Final, Type, TypeInner}; + +/// Type inference context, or handle to a context. +/// +/// Can be cheaply cloned with [`Context::shallow_clone`]. These clones will +/// refer to the same underlying type inference context, and can be used as +/// handles to each other. The derived [`Context::clone`] has the same effect. +/// +/// There is currently no way to create an independent context with the same +/// type inference variables (i.e. a deep clone). If you need this functionality, +/// please file an issue. +#[derive(Clone, Default)] +pub struct Context { + slab: Arc>>, +} + +impl fmt::Debug for Context { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let id = Arc::as_ptr(&self.slab) as usize; + write!(f, "inference_ctx_{:08x}", id) + } +} + +impl PartialEq for Context { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.slab, &other.slab) + } +} +impl Eq for Context {} + +impl Context { + /// Creates a new empty type inference context. + pub fn new() -> Self { + Context { + slab: Arc::new(Mutex::new(vec![])), + } + } + + /// Helper function to allocate a bound and return a reference to it. + fn alloc_bound(&self, bound: Bound) -> BoundRef { + let mut lock = self.lock(); + lock.alloc_bound(Arc::as_ptr(&self.slab), bound) + } + + /// Allocate a new free type bound, and return a reference to it. + pub fn alloc_free(&self, name: String) -> BoundRef { + self.alloc_bound(Bound::Free(name)) + } + + /// Allocate a new unit type bound, and return a reference to it. + pub fn alloc_unit(&self) -> BoundRef { + self.alloc_bound(Bound::Complete(Final::unit())) + } + + /// Allocate a new unit type bound, and return a reference to it. + pub fn alloc_complete(&self, data: Arc) -> BoundRef { + self.alloc_bound(Bound::Complete(data)) + } + + /// Allocate a new sum-type bound, and return a reference to it. + /// + /// # Panics + /// + /// Panics if either of the child types are from a different inference context. + pub fn alloc_sum(&self, left: Type, right: Type) -> BoundRef { + assert_eq!( + left.ctx, *self, + "left type did not match inference context of sum" + ); + assert_eq!( + right.ctx, *self, + "right type did not match inference context of sum" + ); + + let mut lock = self.lock(); + if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) { + lock.alloc_bound( + Arc::as_ptr(&self.slab), + Bound::Complete(Final::sum(data1, data2)), + ) + } else { + lock.alloc_bound(Arc::as_ptr(&self.slab), Bound::Sum(left.inner, right.inner)) + } + } + + /// Allocate a new product-type bound, and return a reference to it. + /// + /// # Panics + /// + /// Panics if either of the child types are from a different inference context. + pub fn alloc_product(&self, left: Type, right: Type) -> BoundRef { + assert_eq!( + left.ctx, *self, + "left type did not match inference context of product" + ); + assert_eq!( + right.ctx, *self, + "right type did not match inference context of product" + ); + + let mut lock = self.lock(); + if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) { + lock.alloc_bound( + Arc::as_ptr(&self.slab), + Bound::Complete(Final::product(data1, data2)), + ) + } else { + lock.alloc_bound( + Arc::as_ptr(&self.slab), + Bound::Product(left.inner, right.inner), + ) + } + } + + /// Creates a new handle to the context. + /// + /// This handle holds a reference to the underlying context and will keep + /// it alive. The context will only be dropped once all handles, including + /// the original context object, are dropped. + pub fn shallow_clone(&self) -> Self { + Self { + slab: Arc::clone(&self.slab), + } + } + + /// Checks whether two inference contexts are equal, and returns an error if not. + pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> { + if self == other { + Ok(()) + } else { + Err(super::Error::InferenceContextMismatch) + } + } + + /// Accesses a bound. + /// + /// # Panics + /// + /// Panics if passed a `BoundRef` that was not allocated by this context. + pub(super) fn get(&self, bound: &BoundRef) -> Bound { + bound.assert_matches_context(self); + let lock = self.lock(); + lock.slab[bound.index].shallow_clone() + } + + /// Reassigns a bound to a different bound. + /// + /// # Panics + /// + /// Panics if called on a complete type. This is a sanity-check to avoid + /// replacing already-completed types, which can cause inefficiencies in + /// the union-bound algorithm (and if our replacement changes the type, + /// this is probably a bug. + /// probably a bug. + /// + /// Also panics if passed a `BoundRef` that was not allocated by this context. + pub(super) fn reassign_non_complete(&self, bound: BoundRef, new: Bound) { + let mut lock = self.lock(); + lock.reassign_non_complete(bound, new); + } + + /// Binds the type to a product bound formed by the two inner types. If this + /// fails, attach the provided hint to the error. + /// + /// Fails if the type has an existing incompatible bound. + pub fn bind_product( + &self, + existing: &Type, + prod_l: &Type, + prod_r: &Type, + hint: &'static str, + ) -> Result<(), Error> { + assert_eq!(existing.ctx, *self); + assert_eq!(prod_l.ctx, *self); + assert_eq!(prod_r.ctx, *self); + + let existing_root = existing.inner.bound.root(); + let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone()); + + let mut lock = self.lock(); + lock.bind(existing_root, new_bound).map_err(|e| { + let new_bound = lock.alloc_bound(Arc::as_ptr(&self.slab), e.new); + Error::Bind { + existing_bound: Type::wrap_bound(self, e.existing), + new_bound: Type::wrap_bound(self, new_bound), + hint, + } + }) + } + + /// Unify the type with another one. + /// + /// Fails if the bounds on the two types are incompatible + pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> { + assert_eq!(ty1.ctx, *self); + assert_eq!(ty2.ctx, *self); + let mut lock = self.lock(); + lock.unify(&ty1.inner, &ty2.inner).map_err(|e| { + let new_bound = lock.alloc_bound(Arc::as_ptr(&self.slab), e.new); + Error::Bind { + existing_bound: Type::wrap_bound(self, e.existing), + new_bound: Type::wrap_bound(self, new_bound), + hint, + } + }) + } + + /// Locks the underlying slab mutex. + fn lock(&self) -> LockedContext { + LockedContext { + slab: self.slab.lock().unwrap(), + } + } +} + +#[derive(Debug, Clone)] +pub struct BoundRef { + context: *const Mutex>, + index: usize, +} + +impl BoundRef { + pub fn assert_matches_context(&self, ctx: &Context) { + assert_eq!( + self.context, + Arc::as_ptr(&ctx.slab), + "bound was accessed from a type inference context that did not create it", + ); + } + + /// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`] + /// with `PartialEq` and `Eq` implemented in terms of underlying pointer + /// equality. + pub fn occurs_check_id(&self) -> OccursCheckId { + OccursCheckId { + context: self.context, + index: self.index, + } + } +} + +impl super::PointerLike for BoundRef { + fn ptr_eq(&self, other: &Self) -> bool { + debug_assert_eq!( + self.context, other.context, + "tried to compare two bounds from different inference contexts" + ); + self.index == other.index + } + + fn shallow_clone(&self) -> Self { + BoundRef { + context: self.context, + index: self.index, + } + } +} + +impl<'ctx> DagLike for (&'ctx Context, BoundRef) { + type Node = BoundRef; + fn data(&self) -> &BoundRef { + &self.1 + } + + fn as_dag_node(&self) -> Dag { + match self.0.get(&self.1) { + Bound::Free(..) | Bound::Complete(..) => Dag::Nullary, + Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => { + Dag::Binary((self.0, ty1.bound.root()), (self.0, ty2.bound.root())) + } + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +pub struct OccursCheckId { + context: *const Mutex>, + // Will become an index into the context in a latter commit, but for + // now we set it to an Arc to preserve semantics. + index: usize, +} + +struct BindError { + existing: BoundRef, + new: Bound, +} + +/// Structure representing an inference context with its slab allocator mutex locked. +/// +/// This type is never exposed outside of this module and should only exist +/// ephemerally within function calls into this module. +struct LockedContext<'ctx> { + slab: MutexGuard<'ctx, Vec>, +} + +impl<'ctx> LockedContext<'ctx> { + fn alloc_bound(&mut self, ctx_ptr: *const Mutex>, bound: Bound) -> BoundRef { + self.slab.push(bound); + let index = self.slab.len() - 1; + + BoundRef { + context: ctx_ptr, + index, + } + } + + fn reassign_non_complete(&mut self, bound: BoundRef, new: Bound) { + assert!( + !matches!(self.slab[bound.index], Bound::Complete(..)), + "tried to modify finalized type", + ); + self.slab[bound.index] = new; + } + + /// It is a common situation that we are pairing two types, and in the + /// case that they are both complete, we want to pair the complete types. + /// + /// This method deals with all the annoying/complicated member variable + /// paths to get the actual complete data out. + fn complete_pair_data( + &self, + inn1: &TypeInner, + inn2: &TypeInner, + ) -> Option<(Arc, Arc)> { + let bound1 = &self.slab[inn1.bound.root().index]; + let bound2 = &self.slab[inn2.bound.root().index]; + if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) { + Some((Arc::clone(data1), Arc::clone(data2))) + } else { + None + } + } + + /// Unify the type with another one. + /// + /// Fails if the bounds on the two types are incompatible + fn unify(&mut self, existing: &TypeInner, other: &TypeInner) -> Result<(), BindError> { + existing.bound.unify(&other.bound, |x_bound, y_bound| { + self.bind(x_bound, self.slab[y_bound.index].shallow_clone()) + }) + } + + fn bind(&mut self, existing: BoundRef, new: Bound) -> Result<(), BindError> { + let existing_bound = self.slab[existing.index].shallow_clone(); + let bind_error = || BindError { + existing: existing.clone(), + new: new.shallow_clone(), + }; + + match (&existing_bound, &new) { + // Binding a free type to anything is a no-op + (_, Bound::Free(_)) => Ok(()), + // Free types are simply dropped and replaced by the new bound + (Bound::Free(_), _) => { + // Free means non-finalized, so set() is ok. + self.reassign_non_complete(existing, new); + Ok(()) + } + // Binding complete->complete shouldn't ever happen, but if so, we just + // compare the two types and return a pass/fail + (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => { + if existing_final == new_final { + Ok(()) + } else { + Err(bind_error()) + } + } + // Binding an incomplete to a complete type requires recursion. + (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => { + match (complete.bound(), incomplete) { + // A unit might match a Bound::Free(..) or a Bound::Complete(..), + // and both cases were handled above. So this is an error. + (CompleteBound::Unit, _) => Err(bind_error()), + ( + CompleteBound::Product(ref comp1, ref comp2), + Bound::Product(ref ty1, ref ty2), + ) + | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => { + let bound1 = ty1.bound.root(); + let bound2 = ty2.bound.root(); + self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?; + self.bind(bound2, Bound::Complete(Arc::clone(comp2))) + } + _ => Err(bind_error()), + } + } + (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2)) + | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => { + self.unify(x1, y1)?; + self.unify(x2, y2)?; + // This type was not complete, but it may be after unification, giving us + // an opportunity to finaliize it. We do this eagerly to make sure that + // "complete" (no free children) is always equivalent to "finalized" (the + // bound field having variant Bound::Complete(..)), even during inference. + // + // It also gives the user access to more information about the type, + // prior to finalization. + if let Some((data1, data2)) = self.complete_pair_data(y1, y2) { + self.reassign_non_complete( + existing, + Bound::Complete(if let Bound::Sum(..) = existing_bound { + Final::sum(data1, data2) + } else { + Final::product(data1, data2) + }), + ); + } + Ok(()) + } + (_, _) => Err(bind_error()), + } + } +} diff --git a/src/types/final_data.rs b/src/types/final_data.rs index 1bfc06e9..8858edd3 100644 --- a/src/types/final_data.rs +++ b/src/types/final_data.rs @@ -12,7 +12,6 @@ //! use crate::dag::{Dag, DagLike, NoSharing}; -use crate::types::{Bound, Type}; use crate::Tmr; use std::sync::Arc; @@ -163,7 +162,7 @@ impl Final { /// /// The type is precomputed and fast to access. pub fn two_two_n(n: usize) -> Arc { - super::precomputed::nth_power_of_2(n).final_data().unwrap() + super::precomputed::nth_power_of_2(n) } /// Create the sum of the given `left` and `right` types. @@ -227,12 +226,6 @@ impl Final { } } -impl From> for Type { - fn from(value: Arc) -> Self { - Type::from(Bound::Complete(value)) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/types/mod.rs b/src/types/mod.rs index 3aad8f91..47220853 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -70,8 +70,8 @@ //! or a sum or product of other types. //! -use self::union_bound::UbElement; -use crate::dag::{Dag, DagLike, NoSharing}; +use self::union_bound::{PointerLike, UbElement}; +use crate::dag::{DagLike, NoSharing}; use crate::Tmr; use std::collections::HashSet; @@ -79,11 +79,13 @@ use std::fmt; use std::sync::Arc; pub mod arrow; +mod context; mod final_data; mod precomputed; mod union_bound; mod variable; +pub use context::{BoundRef, Context}; pub use final_data::{CompleteBound, Final}; /// Error type for simplicity @@ -92,8 +94,8 @@ pub use final_data::{CompleteBound, Final}; pub enum Error { /// An attempt to bind a type conflicted with an existing bound on the type Bind { - existing_bound: Bound, - new_bound: Bound, + existing_bound: Type, + new_bound: Type, hint: &'static str, }, /// Two unequal complete types were attempted to be unified @@ -103,7 +105,10 @@ pub enum Error { hint: &'static str, }, /// A type is recursive (i.e., occurs within itself), violating the "occurs check" - OccursCheck { infinite_bound: Arc }, + OccursCheck { infinite_bound: Type }, + /// Attempted to combine two nodes which had different type inference + /// contexts. This is probably a programming error. + InferenceContextMismatch, } impl fmt::Display for Error { @@ -134,137 +139,26 @@ impl fmt::Display for Error { Error::OccursCheck { infinite_bound } => { write!(f, "infinitely-sized type {}", infinite_bound,) } + Error::InferenceContextMismatch => { + f.write_str("attempted to combine two nodes with different type inference contexts") + } } } } impl std::error::Error for Error {} -mod bound_mutex { - use super::{Bound, CompleteBound, Error, Final}; - use std::fmt; - use std::sync::{Arc, Mutex}; - - /// Source or target type of a Simplicity expression - pub struct BoundMutex { - /// The type's status according to the union-bound algorithm. - inner: Mutex>, - } - - impl fmt::Debug for BoundMutex { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.get().fmt(f) - } - } - - impl BoundMutex { - pub fn new(bound: Bound) -> Self { - BoundMutex { - inner: Mutex::new(Arc::new(bound)), - } - } - - pub fn get(&self) -> Arc { - Arc::clone(&self.inner.lock().unwrap()) - } - - pub fn set(&self, new: Arc) { - let mut lock = self.inner.lock().unwrap(); - assert!( - !matches!(**lock, Bound::Complete(..)), - "tried to modify finalized type", - ); - *lock = new; - } - - pub fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { - let existing_bound = self.get(); - let bind_error = || Error::Bind { - existing_bound: existing_bound.shallow_clone(), - new_bound: bound.shallow_clone(), - hint, - }; - - match (existing_bound.as_ref(), bound.as_ref()) { - // Binding a free type to anything is a no-op - (_, Bound::Free(_)) => Ok(()), - // Free types are simply dropped and replaced by the new bound - (Bound::Free(_), _) => { - // Free means non-finalized, so set() is ok. - self.set(bound); - Ok(()) - } - // Binding complete->complete shouldn't ever happen, but if so, we just - // compare the two types and return a pass/fail - (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => { - if existing_final == new_final { - Ok(()) - } else { - Err(bind_error()) - } - } - // Binding an incomplete to a complete type requires recursion. - (Bound::Complete(complete), incomplete) - | (incomplete, Bound::Complete(complete)) => { - match (complete.bound(), incomplete) { - // A unit might match a Bound::Free(..) or a Bound::Complete(..), - // and both cases were handled above. So this is an error. - (CompleteBound::Unit, _) => Err(bind_error()), - ( - CompleteBound::Product(ref comp1, ref comp2), - Bound::Product(ref ty1, ref ty2), - ) - | ( - CompleteBound::Sum(ref comp1, ref comp2), - Bound::Sum(ref ty1, ref ty2), - ) => { - ty1.bind(Arc::new(Bound::Complete(Arc::clone(comp1))), hint)?; - ty2.bind(Arc::new(Bound::Complete(Arc::clone(comp2))), hint) - } - _ => Err(bind_error()), - } - } - (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2)) - | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => { - x1.unify(y1, hint)?; - x2.unify(y2, hint)?; - // This type was not complete, but it may be after unification, giving us - // an opportunity to finaliize it. We do this eagerly to make sure that - // "complete" (no free children) is always equivalent to "finalized" (the - // bound field having variant Bound::Complete(..)), even during inference. - // - // It also gives the user access to more information about the type, - // prior to finalization. - if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) { - self.set(Arc::new(Bound::Complete(if let Bound::Sum(..) = *bound { - Final::sum(data1, data2) - } else { - Final::product(data1, data2) - }))); - } - Ok(()) - } - (x, y) => Err(Error::Bind { - existing_bound: x.shallow_clone(), - new_bound: y.shallow_clone(), - hint, - }), - } - } - } -} - /// The state of a [`Type`] based on all constraints currently imposed on it. #[derive(Clone)] -pub enum Bound { +enum Bound { /// Fully-unconstrained type Free(String), /// Fully-constrained (i.e. complete) type, which has no free variables. Complete(Arc), /// A sum of two other types - Sum(Type, Type), + Sum(TypeInner, TypeInner), /// A product of two other types - Product(Type, Type), + Product(TypeInner, TypeInner), } impl Bound { @@ -276,102 +170,6 @@ impl Bound { pub fn shallow_clone(&self) -> Bound { self.clone() } - - fn free(name: String) -> Self { - Bound::Free(name) - } - - fn unit() -> Self { - Bound::Complete(Final::unit()) - } - - fn sum(a: Type, b: Type) -> Self { - if let (Some(adata), Some(bdata)) = (a.final_data(), b.final_data()) { - Bound::Complete(Final::sum(adata, bdata)) - } else { - Bound::Sum(a, b) - } - } - - fn product(a: Type, b: Type) -> Self { - if let (Some(adata), Some(bdata)) = (a.final_data(), b.final_data()) { - Bound::Complete(Final::product(adata, bdata)) - } else { - Bound::Product(a, b) - } - } -} - -const MAX_DISPLAY_DEPTH: usize = 64; - -impl fmt::Debug for Bound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let arc = Arc::new(self.shallow_clone()); - for data in arc.verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) { - if data.depth == MAX_DISPLAY_DEPTH { - if data.n_children_yielded == 0 { - f.write_str("...")?; - } - continue; - } - match (&*data.node, data.n_children_yielded) { - (Bound::Free(ref s), _) => f.write_str(s)?, - (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?, - (Bound::Sum(..), 0) | (Bound::Product(..), 0) => f.write_str("(")?, - (Bound::Sum(..), 2) | (Bound::Product(..), 2) => f.write_str(")")?, - (Bound::Sum(..), _) => f.write_str(" + ")?, - (Bound::Product(..), _) => f.write_str(" × ")?, - } - } - Ok(()) - } -} - -impl fmt::Display for Bound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let arc = Arc::new(self.shallow_clone()); - for data in arc.verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) { - if data.depth == MAX_DISPLAY_DEPTH { - if data.n_children_yielded == 0 { - f.write_str("...")?; - } - continue; - } - match (&*data.node, data.n_children_yielded) { - (Bound::Free(ref s), _) => f.write_str(s)?, - (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?, - (Bound::Sum(..), 0) | (Bound::Product(..), 0) => { - if data.index > 0 { - f.write_str("(")?; - } - } - (Bound::Sum(..), 2) | (Bound::Product(..), 2) => { - if data.index > 0 { - f.write_str(")")? - } - } - (Bound::Sum(..), _) => f.write_str(" + ")?, - (Bound::Product(..), _) => f.write_str(" × ")?, - } - } - Ok(()) - } -} - -impl DagLike for Arc { - type Node = Bound; - fn data(&self) -> &Bound { - self - } - - fn as_dag_node(&self) -> Dag { - match **self { - Bound::Free(..) | Bound::Complete(..) => Dag::Nullary, - Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => { - Dag::Binary(ty1.bound.root().get(), ty2.bound.root().get()) - } - } - } } /// Source or target type of a Simplicity expression. @@ -380,39 +178,68 @@ impl DagLike for Arc { /// therefore quite cheap to clone, but be aware that cloning will not /// actually create a new independent type, just a second pointer to the /// first one. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Type { + /// Handle to the type context. + ctx: Context, + /// The actual contents of the type. + inner: TypeInner, +} + +#[derive(Clone)] +struct TypeInner { /// A set of constraints, which maintained by the union-bound algorithm and /// is progressively tightened as type inference proceeds. - bound: UbElement, + bound: UbElement, +} + +impl TypeInner { + fn shallow_clone(&self) -> Self { + self.clone() + } } impl Type { /// Return an unbound type with the given name - pub fn free(name: String) -> Self { - Type::from(Bound::free(name)) + pub fn free(ctx: &Context, name: String) -> Self { + Self::wrap_bound(ctx, ctx.alloc_free(name)) } /// Create the unit type. - pub fn unit() -> Self { - Type::from(Bound::unit()) + pub fn unit(ctx: &Context) -> Self { + Self::wrap_bound(ctx, ctx.alloc_unit()) } /// Create the type `2^(2^n)` for the given `n`. /// /// The type is precomputed and fast to access. - pub fn two_two_n(n: usize) -> Self { - precomputed::nth_power_of_2(n) + pub fn two_two_n(ctx: &Context, n: usize) -> Self { + Self::complete(ctx, precomputed::nth_power_of_2(n)) } /// Create the sum of the given `left` and `right` types. - pub fn sum(left: Self, right: Self) -> Self { - Type::from(Bound::sum(left, right)) + pub fn sum(ctx: &Context, left: Self, right: Self) -> Self { + Self::wrap_bound(ctx, ctx.alloc_sum(left, right)) } /// Create the product of the given `left` and `right` types. - pub fn product(left: Self, right: Self) -> Self { - Type::from(Bound::product(left, right)) + pub fn product(ctx: &Context, left: Self, right: Self) -> Self { + Self::wrap_bound(ctx, ctx.alloc_product(left, right)) + } + + /// Create a complete type. + pub fn complete(ctx: &Context, final_data: Arc) -> Self { + Self::wrap_bound(ctx, ctx.alloc_complete(final_data)) + } + + fn wrap_bound(ctx: &Context, bound: BoundRef) -> Self { + bound.assert_matches_context(ctx); + Type { + ctx: ctx.shallow_clone(), + inner: TypeInner { + bound: UbElement::new(bound), + }, + } } /// Clones the `Type`. @@ -423,27 +250,9 @@ impl Type { self.clone() } - /// Binds the type to a given bound. If this fails, attach the provided - /// hint to the error. - /// - /// Fails if the type has an existing incompatible bound. - pub fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { - let root = self.bound.root(); - root.bind(bound, hint) - } - - /// Unify the type with another one. - /// - /// Fails if the bounds on the two types are incompatible - pub fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> { - self.bound.unify(&other.bound, |x_bound, y_bound| { - x_bound.bind(y_bound.get(), hint) - }) - } - /// Accessor for this type's bound - pub fn bound(&self) -> Arc { - self.bound.root().get() + fn bound(&self) -> Bound { + self.ctx.get(&self.inner.bound.root()) } /// Accessor for the TMR of this type, if it is final @@ -453,7 +262,7 @@ impl Type { /// Accessor for the data of this type, if it is complete pub fn final_data(&self) -> Option> { - if let Bound::Complete(ref data) = *self.bound.root().get() { + if let Bound::Complete(ref data) = self.bound() { Some(Arc::clone(data)) } else { None @@ -466,55 +275,57 @@ impl Type { /// complete, since its children may have been unified to a complete type. To /// ensure a type is complete, call [`Type::finalize`]. pub fn is_final(&self) -> bool { - matches!(*self.bound.root().get(), Bound::Complete(..)) + self.final_data().is_some() } /// Attempts to finalize the type. Returns its TMR on success. pub fn finalize(&self) -> Result, Error> { + use context::OccursCheckId; + /// Helper type for the occurs-check. enum OccursCheckStack { - Iterate(Arc), - Complete(*const Bound), + Iterate(BoundRef), + Complete(OccursCheckId), } // Done with sharing tracker. Actual algorithm follows. - let root = self.bound.root(); - let bound = root.get(); - if let Bound::Complete(ref data) = *bound { + let root = self.inner.bound.root(); + let bound = self.ctx.get(&root); + if let Bound::Complete(ref data) = bound { return Ok(Arc::clone(data)); } // First, do occurs-check to ensure that we have no infinitely sized types. - let mut stack = vec![OccursCheckStack::Iterate(Arc::clone(&bound))]; + let mut stack = vec![OccursCheckStack::Iterate(root)]; let mut in_progress = HashSet::new(); let mut completed = HashSet::new(); while let Some(top) = stack.pop() { let bound = match top { - OccursCheckStack::Complete(ptr) => { - in_progress.remove(&ptr); - completed.insert(ptr); + OccursCheckStack::Complete(id) => { + in_progress.remove(&id); + completed.insert(id); continue; } OccursCheckStack::Iterate(b) => b, }; - let ptr = bound.as_ref() as *const _; - if completed.contains(&ptr) { + let id = bound.occurs_check_id(); + if completed.contains(&id) { // Once we have iterated through a type, we don't need to check it again. // Without this shortcut the occurs-check would take exponential time. continue; } - if !in_progress.insert(ptr) { + if !in_progress.insert(id) { return Err(Error::OccursCheck { - infinite_bound: bound, + infinite_bound: Type::wrap_bound(&self.ctx, bound), }); } - stack.push(OccursCheckStack::Complete(ptr)); - if let Some(child) = bound.right_child() { + stack.push(OccursCheckStack::Complete(id)); + if let Some((_, child)) = (&self.ctx, bound.shallow_clone()).right_child() { stack.push(OccursCheckStack::Iterate(child)); } - if let Some(child) = bound.left_child() { + if let Some((_, child)) = (&self.ctx, bound).left_child() { stack.push(OccursCheckStack::Iterate(child)); } } @@ -522,10 +333,9 @@ impl Type { // Now that we know our types have finite size, we can safely use a // post-order iterator to finalize them. let mut finalized = vec![]; - for data in self.shallow_clone().post_order_iter::() { - let bound = data.node.bound.root(); - let bound_get = bound.get(); - let final_data = match *bound_get { + for data in (&self.ctx, self.inner.bound.root()).post_order_iter::() { + let bound_get = data.node.0.get(&data.node.1); + let final_data = match bound_get { Bound::Free(_) => Final::unit(), Bound::Complete(ref arc) => Arc::clone(arc), Bound::Sum(..) => Final::sum( @@ -538,9 +348,9 @@ impl Type { ), }; - if !matches!(*bound_get, Bound::Complete(..)) { - // set() ok because we are if-guarded on this variable not being complete - bound.set(Arc::new(Bound::Complete(Arc::clone(&final_data)))); + if !matches!(bound_get, Bound::Complete(..)) { + self.ctx + .reassign_non_complete(data.node.1, Bound::Complete(Arc::clone(&final_data))); } finalized.push(final_data); } @@ -548,47 +358,84 @@ impl Type { } /// Return a vector containing the types 2^(2^i) for i from 0 to n-1. - pub fn powers_of_two(n: usize) -> Vec { + pub fn powers_of_two(ctx: &Context, n: usize) -> Vec { let mut ret = Vec::with_capacity(n); - let unit = Type::unit(); - let mut two = Type::sum(unit.shallow_clone(), unit); + let unit = Type::unit(ctx); + let mut two = Type::sum(ctx, unit.shallow_clone(), unit); for _ in 0..n { ret.push(two.shallow_clone()); - two = Type::product(two.shallow_clone(), two); + two = Type::product(ctx, two.shallow_clone(), two); } ret } } -impl fmt::Display for Type { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.bound.root().get(), f) - } -} +const MAX_DISPLAY_DEPTH: usize = 64; -impl From for Type { - /// Promotes a `Bound` to a type defined by that constraint - fn from(bound: Bound) -> Type { - Type { - bound: UbElement::new(Arc::new(bound_mutex::BoundMutex::new(bound))), +impl fmt::Debug for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for data in (&self.ctx, self.inner.bound.root()) + .verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) + { + if data.depth == MAX_DISPLAY_DEPTH { + if data.n_children_yielded == 0 { + f.write_str("...")?; + } + continue; + } + let bound = data.node.0.get(&data.node.1); + match (bound, data.n_children_yielded) { + (Bound::Free(ref s), _) => f.write_str(s)?, + (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?, + (Bound::Sum(..), 0) | (Bound::Product(..), 0) => { + if data.index > 0 { + f.write_str("(")?; + } + } + (Bound::Sum(..), 2) | (Bound::Product(..), 2) => { + if data.index > 0 { + f.write_str(")")? + } + } + (Bound::Sum(..), _) => f.write_str(" + ")?, + (Bound::Product(..), _) => f.write_str(" × ")?, + } } + Ok(()) } } -impl DagLike for Type { - type Node = Type; - fn data(&self) -> &Type { - self - } - - fn as_dag_node(&self) -> Dag { - match *self.bound.root().get() { - Bound::Free(..) | Bound::Complete(..) => Dag::Nullary, - Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => { - Dag::Binary(ty1.shallow_clone(), ty2.shallow_clone()) +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for data in (&self.ctx, self.inner.bound.root()) + .verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) + { + if data.depth == MAX_DISPLAY_DEPTH { + if data.n_children_yielded == 0 { + f.write_str("...")?; + } + continue; + } + let bound = data.node.0.get(&data.node.1); + match (bound, data.n_children_yielded) { + (Bound::Free(ref s), _) => f.write_str(s)?, + (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?, + (Bound::Sum(..), 0) | (Bound::Product(..), 0) => { + if data.index > 0 { + f.write_str("(")?; + } + } + (Bound::Sum(..), 2) | (Bound::Product(..), 2) => { + if data.index > 0 { + f.write_str(")")? + } + } + (Bound::Sum(..), _) => f.write_str(" + ")?, + (Bound::Product(..), _) => f.write_str(" × ")?, } } + Ok(()) } } @@ -601,8 +448,10 @@ mod tests { #[test] fn inference_failure() { + let ctx = Context::new(); + // unit: A -> 1 - let unit = Arc::>::unit(); // 1 -> 1 + let unit = Arc::>::unit(&ctx); // 1 -> 1 // Force unit to be 1->1 Arc::>::comp(&unit, &unit).unwrap(); @@ -618,7 +467,8 @@ mod tests { #[test] fn memory_leak() { - let iden = Arc::>::iden(); + let ctx = Context::new(); + let iden = Arc::>::iden(&ctx); let drop = Arc::>::drop_(&iden); let case = Arc::>::case(&iden, &drop).unwrap(); diff --git a/src/types/precomputed.rs b/src/types/precomputed.rs index 86d6a683..fb67ae32 100644 --- a/src/types/precomputed.rs +++ b/src/types/precomputed.rs @@ -14,33 +14,34 @@ use crate::Tmr; -use super::Type; +use super::Final; use std::cell::RefCell; use std::convert::TryInto; +use std::sync::Arc; // Directly use the size of the precomputed TMR table to make sure they're in sync. const N_POWERS: usize = Tmr::POWERS_OF_TWO.len(); thread_local! { - static POWERS_OF_TWO: RefCell> = RefCell::new(None); + static POWERS_OF_TWO: RefCell; N_POWERS]>> = RefCell::new(None); } -fn initialize(write: &mut Option<[Type; N_POWERS]>) { - let one = Type::unit(); +fn initialize(write: &mut Option<[Arc; N_POWERS]>) { + let one = Final::unit(); let mut powers = Vec::with_capacity(N_POWERS); // Two^(2^0) = Two = (One + One) - let mut power = Type::sum(one.shallow_clone(), one); - powers.push(power.shallow_clone()); + let mut power = Final::sum(Arc::clone(&one), one); + powers.push(Arc::clone(&power)); // Two^(2^(i + 1)) = (Two^(2^i) * Two^(2^i)) for _ in 1..N_POWERS { - power = Type::product(power.shallow_clone(), power); - powers.push(power.shallow_clone()); + power = Final::product(Arc::clone(&power), power); + powers.push(Arc::clone(&power)); } - let powers: [Type; N_POWERS] = powers.try_into().unwrap(); + let powers: [Arc; N_POWERS] = powers.try_into().unwrap(); *write = Some(powers); } @@ -49,12 +50,12 @@ fn initialize(write: &mut Option<[Type; N_POWERS]>) { /// # Panics /// /// Panics if you request a number `n` greater than or equal to [`Tmr::POWERS_OF_TWO`]. -pub fn nth_power_of_2(n: usize) -> Type { +pub fn nth_power_of_2(n: usize) -> Arc { POWERS_OF_TWO.with(|arr| { if arr.borrow().is_none() { initialize(&mut arr.borrow_mut()); } debug_assert!(arr.borrow().is_some()); - arr.borrow().as_ref().unwrap()[n].shallow_clone() + Arc::clone(&arr.borrow().as_ref().unwrap()[n]) }) } diff --git a/src/types/union_bound.rs b/src/types/union_bound.rs index 2a984679..f47438d9 100644 --- a/src/types/union_bound.rs +++ b/src/types/union_bound.rs @@ -32,6 +32,30 @@ use std::sync::{Arc, Mutex}; use std::{cmp, fmt, mem}; +/// Trait describing objects that can be stored and manipulated by the union-bound +/// algorithm. +/// +/// Because the algorithm depends on identity equality (i.e. two objects being +/// exactly the same in memory) such objects need to have such a notion of +/// equality. In general this differs from the `Eq` trait which implements +/// "semantic equality". +pub trait PointerLike { + /// Whether two objects are the same. + fn ptr_eq(&self, other: &Self) -> bool; + + /// A "shallow copy" of the object. + fn shallow_clone(&self) -> Self; +} + +impl PointerLike for Arc { + fn ptr_eq(&self, other: &Self) -> bool { + Arc::ptr_eq(self, other) + } + fn shallow_clone(&self) -> Self { + Arc::clone(self) + } +} + pub struct UbElement { inner: Arc>>, } @@ -44,7 +68,7 @@ impl Clone for UbElement { } } -impl fmt::Debug for UbElement { +impl fmt::Debug for UbElement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&self.root(), f) } @@ -56,29 +80,31 @@ struct UbInner { } enum UbData { - Root(Arc), + Root(T), EqualTo(UbElement), } impl UbData { - fn shallow_clone(&self) -> Self { + fn unwrap_root(&self) -> &T { match self { - UbData::Root(x) => UbData::Root(Arc::clone(x)), - UbData::EqualTo(eq) => UbData::EqualTo(eq.shallow_clone()), + UbData::Root(ref x) => x, + UbData::EqualTo(..) => unreachable!(), } } +} - fn unwrap_root(&self) -> &Arc { +impl UbData { + fn shallow_clone(&self) -> Self { match self { - UbData::Root(ref x) => x, - UbData::EqualTo(..) => unreachable!(), + UbData::Root(x) => UbData::Root(x.shallow_clone()), + UbData::EqualTo(eq) => UbData::EqualTo(eq.shallow_clone()), } } } impl UbElement { /// Turns an existing piece of data into a singleton union-bound set. - pub fn new(data: Arc) -> Self { + pub fn new(data: T) -> Self { UbElement { inner: Arc::new(Mutex::new(UbInner { data: UbData::Root(data), @@ -92,14 +118,18 @@ impl UbElement { /// This is the same as just calling `.clone()` but has a different name to /// emphasize that what's being cloned is internally just an Arc. pub fn shallow_clone(&self) -> Self { - self.clone() + Self { + inner: Arc::clone(&self.inner), + } } +} +impl UbElement { /// Find the representative of this object in its disjoint set. - pub fn root(&self) -> Arc { + pub fn root(&self) -> T { let root = self.root_element(); let inner_lock = root.inner.lock().unwrap(); - Arc::clone(inner_lock.data.unwrap_root()) + inner_lock.data.unwrap_root().shallow_clone() } /// Find the representative of this object in its disjoint set. @@ -145,7 +175,7 @@ impl UbElement { /// to actually be equal. This is accomplished with the `bind_fn` function, /// which takes two arguments: the **new representative that will be kept** /// followed by the **old representative that will be dropped**. - pub fn unify, &Arc) -> Result<(), E>>( + pub fn unify Result<(), E>>( &self, other: &Self, bind_fn: Bind, @@ -167,7 +197,7 @@ impl UbElement { // If our two variables are not literally the same, but through // unification have become the same, we detect _this_ and exit early. - if Arc::ptr_eq(x_lock.data.unwrap_root(), y_lock.data.unwrap_root()) { + if x_lock.data.unwrap_root().ptr_eq(y_lock.data.unwrap_root()) { return Ok(()); } @@ -197,7 +227,7 @@ impl UbElement { } let x_data = match x_lock.data { - UbData::Root(ref arc) => Arc::clone(arc), + UbData::Root(ref data) => data.shallow_clone(), UbData::EqualTo(..) => unreachable!(), }; drop(x_lock);