[mega pr, do not merge] types: rewrite inference API to use a slab allocator for type bounds #228

Closed · wants to merge 16 commits
6 changes: 3 additions & 3 deletions jets-bench/benches/elements/data_structures.rs
@@ -4,9 +4,8 @@
 use bitcoin::secp256k1;
 use elements::Txid;
 use rand::{thread_rng, RngCore};
 pub use simplicity::hashes::sha256;
 use simplicity::{
-    bitcoin, elements, hashes::Hash, hex::FromHex, types::Type, BitIter, Error, Value,
+    bitcoin, elements, hashes::Hash, hex::FromHex, types::{self, Type}, BitIter, Error, Value,
 };
 use std::sync::Arc;
@@ -57,7 +56,8 @@ pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result<Arc<Value>, Error> {
     assert!(n < 16);
     assert!(v.len() < (1 << (n + 1)));
     let mut iter = BitIter::new(v.iter().copied());
-    let types = Type::powers_of_two(n); // size n + 1
+    let ctx = types::Context::new();
+    let types = Type::powers_of_two(&ctx, n); // size n + 1
     let mut res = None;
     while n > 0 {
         let v = if v.len() >= (1 << (n + 1)) {
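
For orientation, the new pattern in one self-contained snippet: type-construction helpers now borrow an explicit `types::Context` (the slab allocator this PR introduces) instead of allocating through shared global state. A minimal sketch based on the signatures visible in this hunk; the `Vec<Type>` return shape is an assumption taken from the `// size n + 1` comment:

    use simplicity::types::{self, Type};

    fn power_types(n: usize) -> Vec<Type> {
        // One slab-backed inference context per session.
        let ctx = types::Context::new();
        // All n + 1 types are allocated inside `ctx`.
        Type::powers_of_two(&ctx, n)
    }
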
6 changes: 3 additions & 3 deletions jets-bench/benches/elements/main.rs
@@ -93,8 +93,8 @@ impl ElementsBenchEnvType {
 }

 fn jet_arrow(jet: Elements) -> (Arc<types::Final>, Arc<types::Final>) {
-    let src_ty = jet.source_ty().to_type().final_data().unwrap();
-    let tgt_ty = jet.target_ty().to_type().final_data().unwrap();
+    let src_ty = jet.source_ty().to_final();
+    let tgt_ty = jet.target_ty().to_final();
     (src_ty, tgt_ty)
 }
@@ -302,7 +302,7 @@ fn bench(c: &mut Criterion) {
     let keypair = bitcoin::key::Keypair::new(&secp_ctx, &mut thread_rng());
     let xpk = bitcoin::key::XOnlyPublicKey::from_keypair(&keypair);

-    let msg = bitcoin::secp256k1::Message::from_slice(&rand::random::<[u8; 32]>()).unwrap();
+    let msg = bitcoin::secp256k1::Message::from_digest_slice(&rand::random::<[u8; 32]>()).unwrap();
     let sig = secp_ctx.sign_schnorr(&msg, &keypair);
     let xpk_value = Value::u256_from_slice(&xpk.0.serialize());
     let sig_value = Value::u512_from_slice(sig.as_ref());
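
The `Message::from_digest_slice` change is the upstream rename of `Message::from_slice` in rust-secp256k1; the `jet_arrow` change shows that jets now expose their finalized arrow directly, with no inference round trip. A sketch of the new access pattern, assuming `simplicity` re-exports `jet::{Elements, Jet}` and `types::Final` as used above:

    use simplicity::jet::{Elements, Jet};
    use simplicity::types;
    use std::sync::Arc;

    // Finalized types come straight from the jet; no more
    // `.to_type().final_data().unwrap()` dance through inference.
    fn jet_arrow(jet: Elements) -> (Arc<types::Final>, Arc<types::Final>) {
        (jet.source_ty().to_final(), jet.target_ty().to_final())
    }
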
3 changes: 2 additions & 1 deletion src/bit_encoding/bitwriter.rs
@@ -117,12 +117,13 @@ mod tests {
     use super::*;
     use crate::jet::Core;
     use crate::node::CoreConstructible;
+    use crate::types;
     use crate::ConstructNode;
     use std::sync::Arc;

     #[test]
     fn vec() {
-        let program = Arc::<ConstructNode<Core>>::unit();
+        let program = Arc::<ConstructNode<Core>>::unit(&types::Context::new());
        let _ = write_to_vec(|w| program.encode(w));
     }
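
Outside the crate, the same construction step would look like this (a sketch; the `simplicity::` paths are assumed equivalents of the `crate::` imports in the test above):

    use simplicity::jet::Core;
    use simplicity::node::CoreConstructible as _;
    use simplicity::{types, ConstructNode};
    use std::sync::Arc;

    fn unit_program() -> Arc<ConstructNode<Core>> {
        // `unit` now takes the context that will own the node's type variables.
        Arc::<ConstructNode<Core>>::unit(&types::Context::new())
    }
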
14 changes: 8 additions & 6 deletions src/bit_encoding/decode.rs
@@ -12,6 +12,7 @@ use crate::node::{
     ConstructNode, CoreConstructible, DisconnectConstructible, JetConstructible, NoWitness,
     WitnessConstructible,
 };
+use crate::types;
 use crate::{BitIter, FailEntropy, Value};
 use std::collections::HashSet;
 use std::sync::Arc;
@@ -178,6 +179,7 @@ pub fn decode_expression<I: Iterator<Item = u8>, J: Jet>(
         return Err(Error::TooManyNodes(len));
     }

+    let inference_context = types::Context::new();
     let mut nodes = Vec::with_capacity(len);
     for _ in 0..len {
         let new_node = decode_node(bits, nodes.len())?;
@@ -195,8 +197,8 @@
         }

         let new = match nodes[data.node.0] {
-            DecodeNode::Unit => Node(ArcNode::unit()),
-            DecodeNode::Iden => Node(ArcNode::iden()),
+            DecodeNode::Unit => Node(ArcNode::unit(&inference_context)),
+            DecodeNode::Iden => Node(ArcNode::iden(&inference_context)),
             DecodeNode::InjL(i) => Node(ArcNode::injl(converted[i].get()?)),
             DecodeNode::InjR(i) => Node(ArcNode::injr(converted[i].get()?)),
             DecodeNode::Take(i) => Node(ArcNode::take(converted[i].get()?)),
@@ -222,16 +224,16 @@
                 converted[i].get()?,
                 &Some(Arc::clone(converted[j].get()?)),
             )?),
-            DecodeNode::Witness => Node(ArcNode::witness(NoWitness)),
-            DecodeNode::Fail(entropy) => Node(ArcNode::fail(entropy)),
+            DecodeNode::Witness => Node(ArcNode::witness(&inference_context, NoWitness)),
+            DecodeNode::Fail(entropy) => Node(ArcNode::fail(&inference_context, entropy)),
             DecodeNode::Hidden(cmr) => {
                 if !hidden_set.insert(cmr) {
                     return Err(Error::SharingNotMaximal);
                 }
                 Hidden(cmr)
             }
-            DecodeNode::Jet(j) => Node(ArcNode::jet(j)),
-            DecodeNode::Word(ref w) => Node(ArcNode::const_word(Arc::clone(w))),
+            DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)),
+            DecodeNode::Word(ref w) => Node(ArcNode::const_word(&inference_context, Arc::clone(w))),
         };
         converted.push(new);
     }
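
The decoder's pattern: allocate one context before the node loop, then let every constructor borrow it, so all nodes of a decoded expression record their type bounds in a single inference arena. A reduced sketch using the public construct-node API rather than the decoder's internal `ArcNode` alias (the `simplicity::` paths are assumptions):

    use simplicity::jet::Core;
    use simplicity::node::CoreConstructible as _;
    use simplicity::{types, ConstructNode};
    use std::sync::Arc;

    fn two_nodes_one_context() {
        let inference_context = types::Context::new();
        // Both nodes share the same slab of type bounds.
        let _unit = Arc::<ConstructNode<Core>>::unit(&inference_context);
        let _iden = Arc::<ConstructNode<Core>>::iden(&inference_context);
    }
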
98 changes: 60 additions & 38 deletions src/human_encoding/named_node.rs
@@ -9,7 +9,7 @@ use crate::node::{
     self, Commit, CommitData, CommitNode, Converter, Inner, NoDisconnect, NoWitness, Node, Witness,
     WitnessData,
 };
-use crate::node::{Construct, ConstructData, Constructible};
+use crate::node::{Construct, ConstructData, Constructible as _, CoreConstructible as _};
 use crate::types;
 use crate::types::arrow::{Arrow, FinalArrow};
 use crate::{encode, Value, WitnessNode};
@@ -113,11 +113,12 @@ impl<J: Jet> NamedCommitNode<J> {
         witness: &HashMap<Arc<str>, Arc<Value>>,
         disconnect: &HashMap<Arc<str>, Arc<NamedCommitNode<J>>>,
     ) -> Arc<WitnessNode<J>> {
-        struct Populator<'a, J: Jet>(
-            &'a HashMap<Arc<str>, Arc<Value>>,
-            &'a HashMap<Arc<str>, Arc<NamedCommitNode<J>>>,
-            PhantomData<J>,
-        );
+        struct Populator<'a, J: Jet> {
+            witness_map: &'a HashMap<Arc<str>, Arc<Value>>,
+            disconnect_map: &'a HashMap<Arc<str>, Arc<NamedCommitNode<J>>>,
+            inference_context: types::Context,
+            phantom: PhantomData<J>,
+        }

         impl<'a, J: Jet> Converter<Named<Commit<J>>, Witness<J>> for Populator<'a, J> {
             type Error = ();
@@ -133,7 +134,7 @@
                 // Which nodes are pruned is not known when this code is executed.
                 // If an unpruned node is unpopulated, then there will be an error
                 // during the finalization.
-                Ok(self.0.get(name).cloned())
+                Ok(self.witness_map.get(name).cloned())
             }

             fn convert_disconnect(
@@ -152,16 +153,17 @@
                     // We keep the missing disconnected branches empty.
                     // Like witness nodes (see above), disconnect nodes may be pruned later.
                     // The finalization will detect missing branches and throw an error.
-                    let maybe_commit = self.1.get(hole_name);
-                    // FIXME: Recursive call of to_witness_node
-                    // We cannot introduce a stack
-                    // because we are implementing methods of the trait Converter
-                    // which are used Marker::convert().
+                    let maybe_commit = self.disconnect_map.get(hole_name);
+                    // FIXME: recursive call to convert
+                    // We cannot introduce a stack because we are implementing the Converter
+                    // trait and do not have access to the actual algorithm used for conversion
+                    // in order to save its state.
                     //
                     // OTOH, if a user writes a program with so many disconnected expressions
                     // that there is a stack overflow, it's his own fault :)
-                    // This would fail in a fuzz test.
-                    let witness = maybe_commit.map(|commit| commit.to_witness_node(self.0, self.1));
+                    // This may fail in a fuzz test.
+                    let witness = maybe_commit
+                        .map(|commit| commit.convert::<InternalSharing, _, _>(self).unwrap());
                     Ok(witness)
                 }
             }
@@ -179,12 +181,18 @@
                 let inner = inner
                     .map(|node| node.cached_data())
                     .map_witness(|maybe_value| maybe_value.clone());
-                Ok(WitnessData::from_inner(inner).expect("types are already finalized"))
+                Ok(WitnessData::from_inner(&self.inference_context, inner)
+                    .expect("types are already finalized"))
             }
         }

-        self.convert::<InternalSharing, _, _>(&mut Populator(witness, disconnect, PhantomData))
-            .unwrap()
+        self.convert::<InternalSharing, _, _>(&mut Populator {
+            witness_map: witness,
+            disconnect_map: disconnect,
+            inference_context: types::Context::new(),
+            phantom: PhantomData,
+        })
+        .unwrap()
     }

     /// Encode a Simplicity expression to bits without any witness data
@@ -239,13 +247,15 @@ pub struct NamedConstructData<J> {
 impl<J: Jet> NamedConstructNode<J> {
     /// Construct a named construct node from parts.
     pub fn new(
+        inference_context: &types::Context,
         name: Arc<str>,
         position: Position,
         user_source_types: Arc<[types::Type]>,
         user_target_types: Arc<[types::Type]>,
         inner: node::Inner<Arc<Self>, J, Arc<Self>, WitnessOrHole>,
     ) -> Result<Self, types::Error> {
         let construct_data = ConstructData::from_inner(
+            inference_context,
             inner
                 .as_ref()
                 .map(|data| &data.cached_data().internal)
@@ -289,6 +299,11 @@ impl<J: Jet> NamedConstructNode<J> {
         self.cached_data().internal.arrow()
     }

+    /// Accessor for the node's type inference context.
+    pub fn inference_context(&self) -> &types::Context {
+        self.cached_data().internal.inference_context()
+    }
+
     /// Finalizes the types of the underlying [`crate::ConstructNode`].
     pub fn finalize_types_main(&self) -> Result<Arc<NamedCommitNode<J>>, ErrorSet> {
         self.finalize_types_inner(true)
@@ -380,17 +395,23 @@ impl<J: Jet> NamedConstructNode<J> {
                 .map_disconnect(|_| &NoDisconnect)
                 .copy_witness();

+            let ctx = data.node.inference_context();
+
             if !self.for_main {
                 // For non-`main` fragments, treat the ascriptions as normative, and apply them
                 // before finalizing the type.
                 let arrow = data.node.arrow();
                 for ty in data.node.cached_data().user_source_types.as_ref() {
-                    if let Err(e) = arrow.source.unify(ty, "binding source type annotation") {
+                    if let Err(e) =
+                        ctx.unify(&arrow.source, ty, "binding source type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
                 for ty in data.node.cached_data().user_target_types.as_ref() {
-                    if let Err(e) = arrow.target.unify(ty, "binding target type annotation") {
+                    if let Err(e) =
+                        ctx.unify(&arrow.target, ty, "binding target type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
@@ -407,19 +428,19 @@
             if self.for_main {
                 // For `main`, only apply type ascriptions *after* inference has completely
                 // determined the type.
-                let source_bound =
-                    types::Bound::Complete(Arc::clone(&commit_data.arrow().source));
-                let source_ty = types::Type::from(source_bound);
+                let source_ty =
+                    types::Type::complete(ctx, Arc::clone(&commit_data.arrow().source));
                 for ty in data.node.cached_data().user_source_types.as_ref() {
-                    if let Err(e) = source_ty.unify(ty, "binding source type annotation") {
+                    if let Err(e) = ctx.unify(&source_ty, ty, "binding source type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
-                let target_bound =
-                    types::Bound::Complete(Arc::clone(&commit_data.arrow().target));
-                let target_ty = types::Type::from(target_bound);
+                let target_ty =
+                    types::Type::complete(ctx, Arc::clone(&commit_data.arrow().target));
                 for ty in data.node.cached_data().user_target_types.as_ref() {
-                    if let Err(e) = target_ty.unify(ty, "binding target type annotation") {
+                    if let Err(e) = ctx.unify(&target_ty, ty, "binding target type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
@@ -440,22 +461,23 @@
         };

         if for_main {
-            let unit_ty = types::Type::unit();
+            let ctx = self.inference_context();
+            let unit_ty = types::Type::unit(ctx);
             if self.cached_data().user_source_types.is_empty() {
-                if let Err(e) = self
-                    .arrow()
-                    .source
-                    .unify(&unit_ty, "setting root source to unit")
-                {
+                if let Err(e) = ctx.unify(
+                    &self.arrow().source,
+                    &unit_ty,
+                    "setting root source to unit",
+                ) {
                     finalizer.errors.add(self.position(), e);
                 }
             }
             if self.cached_data().user_target_types.is_empty() {
-                if let Err(e) = self
-                    .arrow()
-                    .target
-                    .unify(&unit_ty, "setting root source to unit")
-                {
+                if let Err(e) = ctx.unify(
+                    &self.arrow().target,
+                    &unit_ty,
+                    "setting root target to unit",
+                ) {
                     finalizer.errors.add(self.position(), e);
                 }
             }
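
The change repeated throughout this file: unification moves from a method on `Type` to a method on the shared `Context`, i.e. `ty_a.unify(&ty_b, msg)` becomes `ctx.unify(&ty_a, &ty_b, msg)`. A sketch of the new form; the `Result<(), types::Error>` return type is an assumption inferred from the `if let Err(e)` call sites above:

    use simplicity::types;

    // Bind an inferred type to unit, as the root-arrow finalization does.
    fn bind_to_unit(ctx: &types::Context, ty: &types::Type) -> Result<(), types::Error> {
        let unit_ty = types::Type::unit(ctx);
        ctx.unify(ty, &unit_ty, "setting root source to unit")
    }
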
29 changes: 19 additions & 10 deletions src/human_encoding/parse/ast.rs
@@ -82,14 +82,25 @@ pub enum Type {

 impl Type {
     /// Convert to a Simplicity type
-    pub fn reify(self) -> types::Type {
+    pub fn reify(self, ctx: &types::Context) -> types::Type {
         match self {
-            Type::Name(s) => types::Type::free(s),
-            Type::One => types::Type::unit(),
-            Type::Two => types::Type::sum(types::Type::unit(), types::Type::unit()),
-            Type::Product(left, right) => types::Type::product(left.reify(), right.reify()),
-            Type::Sum(left, right) => types::Type::sum(left.reify(), right.reify()),
-            Type::TwoTwoN(n) => types::Type::two_two_n(n as usize), // cast OK as we are only using tiny numbers
+            Type::Name(s) => types::Type::free(ctx, s),
+            Type::One => types::Type::unit(ctx),
+            Type::Two => {
+                let unit_ty = types::Type::unit(ctx);
+                types::Type::sum(ctx, unit_ty.shallow_clone(), unit_ty)
+            }
+            Type::Product(left, right) => {
+                let left = left.reify(ctx);
+                let right = right.reify(ctx);
+                types::Type::product(ctx, left, right)
+            }
+            Type::Sum(left, right) => {
+                let left = left.reify(ctx);
+                let right = right.reify(ctx);
+                types::Type::sum(ctx, left, right)
+            }
+            Type::TwoTwoN(n) => types::Type::two_two_n(ctx, n as usize), // cast OK as we are only using tiny numbers
         }
     }
 }
@@ -633,9 +644,7 @@ fn grammar<J: Jet + 'static>() -> Grammar<Ast<J>> {
                         Error::BadWordLength { bit_length },
                     ));
                 }
-                let ty = types::Type::two_two_n(bit_length.trailing_zeros() as usize)
-                    .final_data()
-                    .unwrap();
+                let ty = types::Final::two_two_n(bit_length.trailing_zeros() as usize);
                 // unwrap ok here since literally every sequence of bits is a valid
                 // value for the given type
                 let value = iter.read_value(&ty).unwrap();
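
Word types for constants now bypass inference entirely: `types::Final::two_two_n(n)` builds the already-finalized type 2^(2^n) directly. Since `bit_length` is a power of two, `bit_length.trailing_zeros()` recovers `n`; for example a 32-bit word has `trailing_zeros() == 5`. A sketch (the return type is left elided since it is not shown in this diff):

    use simplicity::types;

    fn word_type_sketch() {
        // 32-bit words: 32 = 2^5, so n = 5 yields the type 2^(2^5).
        let _ty = types::Final::two_two_n(5);
    }
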
8 changes: 5 additions & 3 deletions src/human_encoding/parse/mod.rs
@@ -7,7 +7,7 @@ mod ast;
 use crate::dag::{Dag, DagLike, InternalSharing};
 use crate::jet::Jet;
 use crate::node;
-use crate::types::Type;
+use crate::types::{self, Type};
 use std::collections::HashMap;
 use std::mem;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -181,6 +181,7 @@ pub fn parse<J: Jet + 'static>(
     program: &str,
 ) -> Result<HashMap<Arc<str>, Arc<NamedCommitNode<J>>>, ErrorSet> {
     let mut errors = ErrorSet::new();
+    let inference_context = types::Context::new();
     // **
     // Step 1: Read expressions into HashMap, checking for dupes and illegal names.
     // **
@@ -205,10 +206,10 @@
             }
         }
         if let Some(ty) = line.arrow.0 {
-            entry.add_source_type(ty.reify());
+            entry.add_source_type(ty.reify(&inference_context));
         }
         if let Some(ty) = line.arrow.1 {
-            entry.add_target_type(ty.reify());
+            entry.add_target_type(ty.reify(&inference_context));
         }
     }

@@ -485,6 +486,7 @@
             .unwrap_or_else(|| Arc::from(namer.assign_name(inner.as_ref()).as_str()));

         let node = NamedConstructNode::new(
+            &inference_context,
             Arc::clone(&name),
             data.node.position,
             Arc::clone(&data.node.user_source_types),