diff --git a/src/types/context.rs b/src/types/context.rs index 7e9da28b..249b9bea 100644 --- a/src/types/context.rs +++ b/src/types/context.rs @@ -15,12 +15,12 @@ //! use std::fmt; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use crate::dag::{Dag, DagLike}; use super::bound_mutex::BoundMutex; -use super::{Bound, Error, Final, Type}; +use super::{Bound, CompleteBound, Error, Final, Type}; /// Type inference context, or handle to a context. /// @@ -156,14 +156,24 @@ impl Context { /// /// Fails if the type has an existing incompatible bound. pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> { - existing.bind(new, hint) + let existing_root = existing.bound.root(); + let lock = self.lock(); + lock.bind(existing_root, new, hint) } /// Unify the type with another one. /// /// Fails if the bounds on the two types are incompatible pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> { - ty1.unify(ty2, hint) + let lock = self.lock(); + lock.unify(ty1, ty2, hint) + } + + /// Locks the underlying slab mutex. + fn lock(&self) -> LockedContext { + LockedContext { + slab: self.slab.lock().unwrap(), + } } } @@ -184,10 +194,6 @@ impl BoundRef { ); } - pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { - self.index.bind(bound, hint) - } - /// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`] /// with `PartialEq` and `Eq` implemented in terms of underlying pointer /// equality. @@ -239,3 +245,97 @@ pub struct OccursCheckId { // now we set it to an Arc to preserve semantics. index: *const BoundMutex, } + +/// Structure representing an inference context with its slab allocator mutex locked. +/// +/// This type is never exposed outside of this module and should only exist +/// ephemerally within function calls into this module. +struct LockedContext<'ctx> { + slab: MutexGuard<'ctx, Vec>, +} + +impl<'ctx> LockedContext<'ctx> { + /// Unify the type with another one. + /// + /// Fails if the bounds on the two types are incompatible + fn unify(&self, existing: &Type, other: &Type, hint: &'static str) -> Result<(), Error> { + existing.bound.unify(&other.bound, |x_bound, y_bound| { + self.bind(x_bound, y_bound.index.get(), hint) + }) + } + + fn bind(&self, existing: BoundRef, new: Bound, hint: &'static str) -> Result<(), Error> { + let existing_bound = existing.index.get(); + let bind_error = || Error::Bind { + existing_bound: existing_bound.shallow_clone(), + new_bound: new.shallow_clone(), + hint, + }; + + match (&existing_bound, &new) { + // Binding a free type to anything is a no-op + (_, Bound::Free(_)) => Ok(()), + // Free types are simply dropped and replaced by the new bound + (Bound::Free(_), _) => { + // Free means non-finalized, so set() is ok. + existing.index.set(new); + Ok(()) + } + // Binding complete->complete shouldn't ever happen, but if so, we just + // compare the two types and return a pass/fail + (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => { + if existing_final == new_final { + Ok(()) + } else { + Err(bind_error()) + } + } + // Binding an incomplete to a complete type requires recursion. + (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => { + match (complete.bound(), incomplete) { + // A unit might match a Bound::Free(..) or a Bound::Complete(..), + // and both cases were handled above. So this is an error. + (CompleteBound::Unit, _) => Err(bind_error()), + ( + CompleteBound::Product(ref comp1, ref comp2), + Bound::Product(ref ty1, ref ty2), + ) + | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => { + let bound1 = ty1.bound.root(); + let bound2 = ty2.bound.root(); + self.bind(bound1, Bound::Complete(Arc::clone(comp1)), hint)?; + self.bind(bound2, Bound::Complete(Arc::clone(comp2)), hint) + } + _ => Err(bind_error()), + } + } + (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2)) + | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => { + self.unify(x1, y1, hint)?; + self.unify(x2, y2, hint)?; + // This type was not complete, but it may be after unification, giving us + // an opportunity to finaliize it. We do this eagerly to make sure that + // "complete" (no free children) is always equivalent to "finalized" (the + // bound field having variant Bound::Complete(..)), even during inference. + // + // It also gives the user access to more information about the type, + // prior to finalization. + if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) { + existing + .index + .set(Bound::Complete(if let Bound::Sum(..) = existing_bound { + Final::sum(data1, data2) + } else { + Final::product(data1, data2) + })); + } + Ok(()) + } + (x, y) => Err(Error::Bind { + existing_bound: x.shallow_clone(), + new_bound: y.shallow_clone(), + hint, + }), + } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b50a9de5..fd5c650d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -149,9 +149,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} mod bound_mutex { - use super::{Bound, CompleteBound, Error, Final}; + use super::Bound; use std::fmt; - use std::sync::{Arc, Mutex}; + use std::sync::Mutex; /// Source or target type of a Simplicity expression pub struct BoundMutex { @@ -184,81 +184,6 @@ mod bound_mutex { ); *lock = new; } - - pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { - let existing_bound = self.get(); - let bind_error = || Error::Bind { - existing_bound: existing_bound.shallow_clone(), - new_bound: bound.shallow_clone(), - hint, - }; - - match (&existing_bound, &bound) { - // Binding a free type to anything is a no-op - (_, Bound::Free(_)) => Ok(()), - // Free types are simply dropped and replaced by the new bound - (Bound::Free(_), _) => { - // Free means non-finalized, so set() is ok. - self.set(bound); - Ok(()) - } - // Binding complete->complete shouldn't ever happen, but if so, we just - // compare the two types and return a pass/fail - (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => { - if existing_final == new_final { - Ok(()) - } else { - Err(bind_error()) - } - } - // Binding an incomplete to a complete type requires recursion. - (Bound::Complete(complete), incomplete) - | (incomplete, Bound::Complete(complete)) => { - match (complete.bound(), incomplete) { - // A unit might match a Bound::Free(..) or a Bound::Complete(..), - // and both cases were handled above. So this is an error. - (CompleteBound::Unit, _) => Err(bind_error()), - ( - CompleteBound::Product(ref comp1, ref comp2), - Bound::Product(ref ty1, ref ty2), - ) - | ( - CompleteBound::Sum(ref comp1, ref comp2), - Bound::Sum(ref ty1, ref ty2), - ) => { - ty1.bind(Bound::Complete(Arc::clone(comp1)), hint)?; - ty2.bind(Bound::Complete(Arc::clone(comp2)), hint) - } - _ => Err(bind_error()), - } - } - (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2)) - | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => { - x1.unify(y1, hint)?; - x2.unify(y2, hint)?; - // This type was not complete, but it may be after unification, giving us - // an opportunity to finaliize it. We do this eagerly to make sure that - // "complete" (no free children) is always equivalent to "finalized" (the - // bound field having variant Bound::Complete(..)), even during inference. - // - // It also gives the user access to more information about the type, - // prior to finalization. - if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) { - self.set(Bound::Complete(if let Bound::Sum(..) = bound { - Final::sum(data1, data2) - } else { - Final::product(data1, data2) - })); - } - Ok(()) - } - (x, y) => Err(Error::Bind { - existing_bound: x.shallow_clone(), - new_bound: y.shallow_clone(), - hint, - }), - } - } } } @@ -391,24 +316,6 @@ impl Type { self.clone() } - /// Binds the type to a given bound. If this fails, attach the provided - /// hint to the error. - /// - /// Fails if the type has an existing incompatible bound. - fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { - let root = self.bound.root(); - root.bind(bound, hint) - } - - /// Unify the type with another one. - /// - /// Fails if the bounds on the two types are incompatible - fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> { - self.bound.unify(&other.bound, |x_bound, y_bound| { - x_bound.bind(self.ctx.get(y_bound), hint) - }) - } - /// Accessor for this type's bound pub fn bound(&self) -> Bound { self.ctx.get(&self.bound.root())