diff --git a/shardtree/src/batch.rs b/shardtree/src/batch.rs index 632cb5f..c4692aa 100644 --- a/shardtree/src/batch.rs +++ b/shardtree/src/batch.rs @@ -1,6 +1,6 @@ //! Helpers for inserting many leaves into a tree at once. -use std::{collections::BTreeMap, fmt, ops::Range, sync::Arc}; +use std::{collections::BTreeMap, fmt, ops::Range}; use incrementalmerkletree::{Address, Hashable, Level, Position, Retention}; use tracing::trace; @@ -8,7 +8,7 @@ use tracing::trace; use crate::{ error::{InsertionError, ShardTreeError}, store::{Checkpoint, ShardStore}, - IncompleteAt, LocatedPrunableTree, LocatedTree, Node, RetentionFlags, ShardTree, Tree, + IncompleteAt, LocatedPrunableTree, LocatedTree, RetentionFlags, ShardTree, Tree, }; impl< @@ -203,9 +203,7 @@ impl LocatedPrunableTree { let rflags = RetentionFlags::from(retention); let mut subtree = LocatedTree { root_addr: Address::from(position), - root: Tree(Node::Leaf { - value: (value.clone(), rflags), - }), + root: Tree::leaf((value.clone(), rflags)), }; if position.is_right_child() { @@ -268,11 +266,7 @@ fn unite( root: if lroot.root_addr.level() < prune_below { Tree::unite(lroot.root_addr.level(), None, lroot.root, rroot.root) } else { - Tree(Node::Parent { - ann: None, - left: Arc::new(lroot.root), - right: Arc::new(rroot.root), - }) + Tree::parent(None, lroot.root, rroot.root) }, } } @@ -297,7 +291,7 @@ fn combine_with_empty( }); let sibling = LocatedTree { root_addr: sibling_addr, - root: Tree(Node::Nil), + root: Tree::empty(), }; let (lroot, rroot) = if root.root_addr.is_left_child() { (root, sibling) diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index 7fe6cf9..398b24b 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -424,7 +424,7 @@ impl< l_addr.level(), ann.clone(), new_left, - Tree(Node::Nil), + Tree::empty(), ), pos, ) @@ -444,9 +444,7 @@ impl< ) } Node::Leaf { value: (h, r) } => Some(( - Tree(Node::Leaf { - value: (h.clone(), *r | RetentionFlags::CHECKPOINT), - }), + Tree::leaf((h.clone(), *r | RetentionFlags::CHECKPOINT)), root_addr.max_position(), )), Node::Nil | Node::Pruned => None, @@ -822,6 +820,10 @@ impl< (None, None) => unreachable!(), }; + // We don't use the `Tree::parent` constructor here, because it + // creates `Arc`s for the child nodes internally, but if we don't + // have a new child then we want to use the `Arc` for the existing + // child. let new_parent = Tree(Node::Parent { ann: new_left .as_ref() diff --git a/shardtree/src/prunable.rs b/shardtree/src/prunable.rs index 5085e31..d230093 100644 --- a/shardtree/src/prunable.rs +++ b/shardtree/src/prunable.rs @@ -246,9 +246,7 @@ impl PrunableTree { (Tree(Node::Leaf { value: vl }), Tree(Node::Leaf { value: vr })) => { if vl.0 == vr.0 { // Merge the flags together. - Ok(Tree(Node::Leaf { - value: (vl.0, vl.1 | vr.1), - })) + Ok(Tree::leaf((vl.0, vl.1 | vr.1))) } else { trace!(left = ?vl.0, right = ?vr.0, "Merge conflict for leaves"); Err(addr) @@ -314,7 +312,7 @@ impl PrunableTree { /// `level` must be the level of the two nodes that are being joined. pub(crate) fn unite(level: Level, ann: Option>, left: Self, right: Self) -> Self { match (left, right) { - (Tree(Node::Nil), Tree(Node::Nil)) => Tree(Node::Nil), + (Tree(Node::Nil), Tree(Node::Nil)) => Tree::empty(), (Tree(Node::Leaf { value: lv }), Tree(Node::Leaf { value: rv })) // we can prune right-hand leaves that are not marked or reference leaves; if a // leaf is a checkpoint then that information will be propagated to the replacement @@ -322,18 +320,12 @@ impl PrunableTree { if lv.1 == RetentionFlags::EPHEMERAL && (rv.1 & (RetentionFlags::MARKED | RetentionFlags::REFERENCE)) == RetentionFlags::EPHEMERAL => { - Tree( - Node::Leaf { - value: (H::combine(level, &lv.0, &rv.0), rv.1), - }, - ) + Tree::leaf((H::combine(level, &lv.0, &rv.0), rv.1)) } - (left, right) => Tree( - Node::Parent { + (left, right) => Tree::parent( ann, - left: Arc::new(left), - right: Arc::new(right), - }, + left, + right, ), } } @@ -514,7 +506,7 @@ impl LocatedPrunableTree { // to the left to truncate the left child and then reconstruct the // node with `Nil` as the right sibling go(position, l_child, left.as_ref()).map(|left| { - Tree::unite(l_child.level(), ann.clone(), left, Tree(Node::Nil)) + Tree::unite(l_child.level(), ann.clone(), left, Tree::empty()) }) } else { // we are truncating within the range of the right node, so recurse @@ -584,37 +576,32 @@ impl LocatedPrunableTree { // In the case that we are replacing a node entirely, we need to extend the // subtree up to the level of the node being replaced, adding Nil siblings // and recording the presence of those incomplete nodes when necessary - let replacement = |ann: Option>, - mut node: LocatedPrunableTree, - pruned: bool| { - // construct the replacement node bottom-up - let mut incomplete = vec![]; - while node.root_addr.level() < root_addr.level() { - incomplete.push(IncompleteAt { - address: node.root_addr.sibling(), - required_for_witness: contains_marked, - }); - let empty = Arc::new(Tree(if pruned { Node::Pruned } else { Node::Nil })); - let full = Arc::new(node.root); - node = LocatedTree { - root_addr: node.root_addr.parent(), - root: if node.root_addr.is_right_child() { - Tree(Node::Parent { - ann: None, - left: empty, - right: full, - }) + let replacement = + |ann: Option>, mut node: LocatedPrunableTree, pruned: bool| { + // construct the replacement node bottom-up + let mut incomplete = vec![]; + while node.root_addr.level() < root_addr.level() { + incomplete.push(IncompleteAt { + address: node.root_addr.sibling(), + required_for_witness: contains_marked, + }); + let empty = if pruned { + Tree::empty_pruned() } else { - Tree(Node::Parent { - ann: None, - left: full, - right: empty, - }) - }, - }; - } - (node.root.reannotate_root(ann), incomplete) - }; + Tree::empty() + }; + let full = node.root; + node = LocatedTree { + root_addr: node.root_addr.parent(), + root: if node.root_addr.is_right_child() { + Tree::parent(None, empty, full) + } else { + Tree::parent(None, full, empty) + }, + }; + } + (node.root.reannotate_root(ann), incomplete) + }; match into { Tree(Node::Nil) => Ok(replacement(None, subtree, false)), @@ -936,16 +923,14 @@ impl LocatedPrunableTree { // a partially-pruned branch, and if it's a marked node then it will // be a level-0 leaf. match to_clear { - [(_, flags)] => Tree(Node::Leaf { - value: (h.clone(), *r & !*flags), - }), + [(_, flags)] => Tree::leaf((h.clone(), *r & !*flags)), _ => { panic!("Tree state inconsistent with checkpoints."); } } } - Node::Nil => Tree(Node::Nil), - Node::Pruned => Tree(Node::Pruned), + Node::Nil => Tree::empty(), + Node::Pruned => Tree::empty_pruned(), } } } diff --git a/shardtree/src/store/memory.rs b/shardtree/src/store/memory.rs index aa16f52..bf3d2fd 100644 --- a/shardtree/src/store/memory.rs +++ b/shardtree/src/store/memory.rs @@ -6,7 +6,7 @@ use std::convert::{Infallible, TryFrom}; use incrementalmerkletree::Address; use super::{Checkpoint, ShardStore}; -use crate::{LocatedPrunableTree, LocatedTree, Node, PrunableTree, Tree}; +use crate::{LocatedPrunableTree, LocatedTree, PrunableTree, Tree}; /// An implementation of [`ShardStore`] that stores all state in memory. /// @@ -54,7 +54,7 @@ impl ShardStore for MemoryShardStore { { self.shards.push(LocatedTree { root_addr: Address::from_parts(subtree_addr.level(), subtree_idx), - root: Tree(Node::Nil), + root: Tree::empty(), }) } diff --git a/shardtree/src/testing.rs b/shardtree/src/testing.rs index 1cd3046..6a56d0f 100644 --- a/shardtree/src/testing.rs +++ b/shardtree/src/testing.rs @@ -30,22 +30,15 @@ where A::Value: Clone + 'static, V::Value: Clone + 'static, { - let leaf = prop_oneof![ - Just(Tree(Node::Nil)), - arb_leaf.prop_map(|value| Tree(Node::Leaf { value })) - ]; + let leaf = prop_oneof![Just(Tree::empty()), arb_leaf.prop_map(Tree::leaf)]; leaf.prop_recursive(depth, size, 2, move |inner| { (arb_annotation.clone(), inner.clone(), inner).prop_map(|(ann, left, right)| { - Tree(if left.is_nil() && right.is_nil() { - Node::Nil + if left.is_nil() && right.is_nil() { + Tree::empty() } else { - Node::Parent { - ann, - left: Arc::new(left), - right: Arc::new(right), - } - }) + Tree::parent(ann, left, right) + } }) }) } diff --git a/shardtree/src/tree.rs b/shardtree/src/tree.rs index d0f8913..6682c7b 100644 --- a/shardtree/src/tree.rs +++ b/shardtree/src/tree.rs @@ -315,7 +315,7 @@ impl LocatedTree { pub fn empty(root_addr: Address) -> Self { Self { root_addr, - root: Tree(Node::Nil), + root: Tree::empty(), } } @@ -324,7 +324,7 @@ impl LocatedTree { pub fn with_root_value(root_addr: Address, value: V) -> Self { Self { root_addr, - root: Tree(Node::Leaf { value }), + root: Tree::leaf(value), } } @@ -412,12 +412,10 @@ impl LocatedTree { pub(crate) mod tests { use incrementalmerkletree::{Address, Level}; - use super::{LocatedTree, Node, Tree}; + use super::{LocatedTree, Tree}; pub(crate) fn str_leaf(c: &str) -> Tree { - Tree(Node::Leaf { - value: c.to_string(), - }) + Tree::leaf(c.to_string()) } pub(crate) fn nil() -> Tree {