Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verified snapshot #6

Merged
merged 3 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/prove-and-verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::rc::Rc;
use kairos_trie::{
stored::{
memory_db::MemoryDb,
merkle::{Snapshot, SnapshotBuilder},
merkle::{Snapshot, SnapshotBuilder, VerifiedSnapshot},
Store,
},
DigestHasher, KeyHash, NodeHash, PortableHash, PortableHasher, Transaction, TrieRoot,
Expand All @@ -21,7 +21,7 @@ fn hash(key: &str) -> KeyHash {
KeyHash::from_bytes(&hasher.finalize_reset())
}

fn apply_operations(txn: &mut Transaction<impl Store<u64>, u64>, operations: &[Ops]) {
fn apply_operations(txn: &mut Transaction<impl Store<Value = u64>>, operations: &[Ops]) {
for op in operations {
match op {
Ops::Add(key, value) => {
Expand Down Expand Up @@ -72,7 +72,9 @@ fn verifier(
) -> TrieRoot<NodeHash> {
let hasher = &mut DigestHasher::<Sha256>::default();

let mut txn = Transaction::from_snapshot(snapshot).unwrap();
let mut txn = Transaction::from_verified_snapshot(
VerifiedSnapshot::verify_snapshot(snapshot, hasher).unwrap(),
);

let pre_batch_trie_root = txn.calc_root_hash(hasher).unwrap();
// Assert that the trie started the transaction with the correct root hash.
Expand Down
34 changes: 25 additions & 9 deletions src/stored.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,30 @@ use alloc::{rc::Rc, sync::Arc};

use crate::{
transaction::nodes::{Branch, Leaf, Node},
NodeHash, PortableHasher,
NodeHash, PortableHash, PortableHasher,
};

pub type Idx = u32;

pub trait Store<V> {
pub trait Store {
type Error: Display;
type Value: Clone + PortableHash;

fn calc_subtree_hash(
&self,
hasher: &mut impl PortableHasher<32>,
hash_idx: Idx,
) -> Result<NodeHash, Self::Error>;

fn get_node(&self, hash_idx: Idx) -> Result<Node<&Branch<Idx>, &Leaf<V>>, Self::Error>;
fn get_node(
&self,
hash_idx: Idx,
) -> Result<Node<&Branch<Idx>, &Leaf<Self::Value>>, Self::Error>;
}

impl<V, S: Store<V>> Store<V> for &S {
impl<S: Store> Store for &S {
type Error = S::Error;
type Value = S::Value;

#[inline(always)]
fn calc_subtree_hash(
Expand All @@ -38,13 +43,17 @@ impl<V, S: Store<V>> Store<V> for &S {
}

#[inline(always)]
fn get_node(&self, hash_idx: Idx) -> Result<Node<&Branch<Idx>, &Leaf<V>>, Self::Error> {
fn get_node(
&self,
hash_idx: Idx,
) -> Result<Node<&Branch<Idx>, &Leaf<Self::Value>>, Self::Error> {
(**self).get_node(hash_idx)
}
}

impl<V, S: Store<V>> Store<V> for Rc<S> {
impl<S: Store> Store for Rc<S> {
type Error = S::Error;
type Value = S::Value;

#[inline(always)]
fn calc_subtree_hash(
Expand All @@ -56,13 +65,17 @@ impl<V, S: Store<V>> Store<V> for Rc<S> {
}

#[inline(always)]
fn get_node(&self, hash_idx: Idx) -> Result<Node<&Branch<Idx>, &Leaf<V>>, Self::Error> {
fn get_node(
&self,
hash_idx: Idx,
) -> Result<Node<&Branch<Idx>, &Leaf<Self::Value>>, Self::Error> {
(**self).get_node(hash_idx)
}
}

impl<V, S: Store<V>> Store<V> for Arc<S> {
impl<S: Store> Store for Arc<S> {
type Error = S::Error;
type Value = S::Value;

#[inline(always)]
fn calc_subtree_hash(
Expand All @@ -74,7 +87,10 @@ impl<V, S: Store<V>> Store<V> for Arc<S> {
}

#[inline(always)]
fn get_node(&self, hash_idx: Idx) -> Result<Node<&Branch<Idx>, &Leaf<V>>, Self::Error> {
fn get_node(
&self,
hash_idx: Idx,
) -> Result<Node<&Branch<Idx>, &Leaf<Self::Value>>, Self::Error> {
(**self).get_node(hash_idx)
}
}
Expand Down
220 changes: 215 additions & 5 deletions src/stored/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,189 @@ use super::{DatabaseGet, Idx, Node, NodeHash, Store};

type Result<T, E = TrieError> = core::result::Result<T, E>;

/// A snapshot of the merkle trie together with the cached merkle hashes of its nodes.
///
/// Contains visited nodes and unvisited nodes (through the wrapped snapshot),
/// plus one cached hash per branch and one per leaf.
/// Values are produced by `VerifiedSnapshot::verify_snapshot`.
///
/// NOTE(review): with the `serde` feature this type derives `Deserialize`, which looks like
/// it allows building a `VerifiedSnapshot` without running `verify_snapshot`.
/// Confirm that deserialized values are only accepted from trusted sources.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VerifiedSnapshot<S: Store> {
// The wrapped snapshot; accessed through `AsRef<Snapshot<_>>` by the impls below.
snapshot: S,

/// When the snapshot has at least one branch, the root hash of the snapshot is the last hash in the slice.
/// The indexes of each hash match the indexes of nodes in the snapshot.
branch_hashes: Box<[NodeHash]>,
// One hash per leaf, pushed in the same order as the snapshot's leaves.
leaf_hashes: Box<[NodeHash]>,
}

impl<S: Store + AsRef<Snapshot<S::Value>>> VerifiedSnapshot<S> {
    /// Verify the snapshot by checking that it is well formed and calculating the merkle hashes of all nodes.
    /// The merkle hashes are cached such that `calc_subtree_hash` is an O(1) operation for all nodes in the snapshot.
    /// In practice this means if a transaction gets, but does not modify a node the hash is never recalculated.
    ///
    /// Storing the hashes of all nodes does increase memory usage,
    /// so for some use cases it may be better to use `Snapshot` directly.
    /// If you choose to do that, be sure to run `calc_root_hash` before using the snapshot.
    ///
    /// # Errors
    ///
    /// Returns an error if `Snapshot::root_node_idx` rejects the snapshot, or if a branch
    /// refers to a child that has not been hashed yet (a branch at the same or a later index)
    /// or to a node index that does not exist.
    #[inline]
    pub fn verify_snapshot(snapshot: S, hasher: &mut impl PortableHasher<32>) -> Result<Self> {
        let snapshot_ref = snapshot.as_ref();

        // Check that the snapshot is well formed.
        let _ = snapshot_ref.root_node_idx()?;

        // Leaves do not depend on any other node, so hash all of them up front.
        let mut leaf_hashes: Vec<NodeHash> = Vec::with_capacity(snapshot_ref.leaves.len());
        leaf_hashes.extend(snapshot_ref.leaves.iter().map(|leaf| leaf.hash_leaf(hasher)));

        let mut branch_hashes: Vec<NodeHash> = Vec::with_capacity(snapshot_ref.branches.len());

        // Node indexes are laid out as: branches, then leaves, then unvisited nodes.
        let leaf_offset = snapshot_ref.branches.len();
        let unvisited_offset = leaf_offset + snapshot_ref.leaves.len();

        for (idx, branch) in snapshot_ref.branches.iter().enumerate() {
            // Look up the already known hash of a child by its node index.
            let hash_of_child = |child: usize| {
                if child < leaf_offset {
                    // Only branches hashed so far are in `branch_hashes`,
                    // so a child branch must come before its parent.
                    branch_hashes.get(child).ok_or_else(|| {
                        format!(
                            "Invalid snapshot: branch {} has child {}, \
                            child branch index must be less than parent, branches are not in post-order traversal",
                            idx, child
                        )
                    })
                } else if child < unvisited_offset {
                    leaf_hashes.get(child - leaf_offset).ok_or_else(|| {
                        format!(
                            "Invalid snapshot: branch {} has child {}, \
                            child leaf does not exist",
                            idx, child
                        )
                    })
                } else {
                    snapshot_ref
                        .unvisited_nodes
                        .get(child - unvisited_offset)
                        .ok_or_else(|| {
                            format!(
                                "Invalid snapshot: branch {} has child {}, \
                                child unvisited node does not exist",
                                idx, child
                            )
                        })
                }
            };

            let left_hash = hash_of_child(branch.left as usize)?;
            let right_hash = hash_of_child(branch.right as usize)?;

            // Finish reading the child hashes before growing `branch_hashes`.
            let branch_hash = branch.hash_branch(hasher, left_hash, right_hash);
            branch_hashes.push(branch_hash);
        }

        Ok(VerifiedSnapshot {
            snapshot,
            branch_hashes: branch_hashes.into_boxed_slice(),
            leaf_hashes: leaf_hashes.into_boxed_slice(),
        })
    }

    /// Returns a reference to the root node of the trie in the snapshot.
    ///
    /// The root is the last branch when branches exist, `TrieRoot::Empty` when the snapshot
    /// holds no nodes at all, and node `0` otherwise.
    #[inline]
    pub fn trie_root(&self) -> TrieRoot<NodeRef<S::Value>> {
        let snapshot = self.snapshot.as_ref();

        if !snapshot.branches.is_empty() {
            TrieRoot::Node(NodeRef::Stored(snapshot.branches.len() as Idx - 1))
        } else if snapshot.leaves.is_empty() && snapshot.unvisited_nodes.is_empty() {
            TrieRoot::Empty
        } else {
            // We know that the snapshot is valid, because we verified it in `verify_snapshot`.
            // If a non-empty snapshot contains no branches, it must have a single leaf or unvisited node.
            // Any other combination is invalid.
            debug_assert_eq!(snapshot.leaves.len() + snapshot.unvisited_nodes.len(), 1);
            TrieRoot::Node(NodeRef::Stored(0))
        }
    }

    /// Returns the merkle root hash of the trie in the snapshot.
    /// The hash of all nodes has already been calculated in `VerifiedSnapshot::verify_snapshot`.
    /// `trie_root_hash` and `calc_root_hash` are both O(1) operations on a `VerifiedSnapshot`,
    /// unlike `Snapshot` and `SnapshotBuilder`.
    #[inline]
    pub fn trie_root_hash(&self) -> TrieRoot<NodeHash> {
        self.branch_hashes
            .last()
            // Given a valid snapshot: if no branches exist, there can only be one leaf or one unvisited node.
            .or_else(|| self.leaf_hashes.first())
            .or_else(|| self.snapshot.as_ref().unvisited_nodes.first())
            .map_or(TrieRoot::Empty, |hash| TrieRoot::Node(*hash))
    }
}

impl<S: Store + AsRef<Snapshot<S::Value>>> Store for VerifiedSnapshot<S> {
    type Error = TrieError;
    type Value = S::Value;

    /// Returns the cached hash of node `node` without using the hasher.
    ///
    /// Node indexes are laid out as: branches, then leaves, then unvisited nodes.
    ///
    /// # Errors
    ///
    /// Returns an error if no hash is available for `node`.
    #[inline]
    fn calc_subtree_hash(
        &self,
        _: &mut impl PortableHasher<32>,
        node: Idx,
    ) -> Result<NodeHash, Self::Error> {
        let snapshot = self.snapshot.as_ref();

        let idx = node as usize;
        let leaf_offset = snapshot.branches.len();
        let unvisited_offset = leaf_offset + snapshot.leaves.len();

        // Pick the table by index range before subtracting an offset.
        // This keeps the subtraction from underflowing even if the cached hash slices
        // do not match the snapshot lengths (e.g. a value that was deserialized
        // rather than built by `verify_snapshot`).
        let cached = if idx < leaf_offset {
            self.branch_hashes.get(idx)
        } else if idx < unvisited_offset {
            self.leaf_hashes.get(idx - leaf_offset)
        } else {
            snapshot.unvisited_nodes.get(idx - unvisited_offset)
        };

        match cached {
            Some(hash) => Ok(*hash),
            None => Err(format!(
                "Invalid arg: node {} does not exist\n\
                Snapshot has {} nodes",
                idx,
                snapshot.branches.len() + snapshot.leaves.len() + snapshot.unvisited_nodes.len(),
            )
            .into()),
        }
    }

    /// Returns the visited branch or leaf stored at `idx`.
    ///
    /// # Errors
    ///
    /// Returns an error if `idx` refers to an unvisited node (only its hash is known)
    /// or to no node at all.
    #[inline]
    fn get_node(&self, idx: Idx) -> Result<Node<&Branch<Idx>, &Leaf<S::Value>>> {
        let snapshot = self.snapshot.as_ref();

        let idx = idx as usize;
        let leaf_offset = snapshot.branches.len();
        let unvisited_offset = leaf_offset + snapshot.leaves.len();

        if let Some(branch) = snapshot.branches.get(idx) {
            Ok(Node::Branch(branch))
        // Not a branch, so `idx >= leaf_offset` and the subtraction cannot underflow.
        } else if let Some(leaf) = snapshot.leaves.get(idx - leaf_offset) {
            Ok(Node::Leaf(leaf))
        // Not a leaf either, so `idx >= unvisited_offset`.
        } else if snapshot
            .unvisited_nodes
            .get(idx - unvisited_offset)
            .is_some()
        {
            Err(format!(
                "Invalid arg: node {idx} is unvisited\n\
                get_node can only return visited nodes"
            )
            .into())
        } else {
            Err(format!(
                "Invalid arg: node {} does not exist\n\
                Snapshot has {} nodes",
                idx,
                snapshot.branches.len() + snapshot.leaves.len() + snapshot.unvisited_nodes.len(),
            )
            .into())
        }
    }
}

/// A snapshot of the merkle trie
///
/// Contains visited nodes and unvisited nodes
Expand All @@ -28,7 +211,14 @@ pub struct Snapshot<V> {
unvisited_nodes: Box<[NodeHash]>,
}

impl<V: PortableHash> Snapshot<V> {
impl<V> AsRef<Snapshot<V>> for Snapshot<V> {
    /// Identity conversion: a snapshot viewed as a snapshot is itself.
    #[inline]
    fn as_ref(&self) -> &Self {
        self
    }
}

impl<V: Clone + PortableHash> Snapshot<V> {
#[inline]
pub fn root_node_idx(&self) -> Result<TrieRoot<Idx>> {
// Revisit this once https://github.com/rust-lang/rust/issues/37854 is stable
Expand All @@ -39,8 +229,9 @@ impl<V: PortableHash> Snapshot<V> {
) {
// An empty tree
([], [], []) => Ok(TrieRoot::Empty),
// A tree with only one node
([_], [], []) | ([], [_], []) | ([], [], [_]) => Ok(TrieRoot::Node(0)),
// A tree with only one node, it must be a leaf or unvisited node.
// It can't be a branch because branches have children.
([], [_], []) | ([], [], [_]) => Ok(TrieRoot::Node(0)),
(branches, _, _) if !branches.is_empty() => {
Ok(TrieRoot::Node(branches.len() as Idx - 1))
}
Expand Down Expand Up @@ -81,10 +272,28 @@ impl<V: PortableHash> Snapshot<V> {
TrieRoot::Empty => Ok(TrieRoot::Empty),
}
}

/// Check that this snapshot is well formed and compute the merkle hashes of all its nodes,
/// borrowing the snapshot instead of consuming it.
/// Forwards to `VerifiedSnapshot::verify_snapshot` with `&Self` as the wrapped store.
#[inline]
pub fn verify_ref(
    &self,
    hasher: &mut impl PortableHasher<32>,
) -> Result<VerifiedSnapshot<&Self>> {
    VerifiedSnapshot::<&Self>::verify_snapshot(self, hasher)
}

/// Check that this snapshot is well formed and compute the merkle hashes of all its nodes,
/// taking ownership of the snapshot.
/// Forwards to `VerifiedSnapshot::verify_snapshot` with `Self` as the wrapped store.
#[inline]
pub fn verify(self, hasher: &mut impl PortableHasher<32>) -> Result<VerifiedSnapshot<Self>> {
    VerifiedSnapshot::<Self>::verify_snapshot(self, hasher)
}
}

impl<V: PortableHash> Store<V> for Snapshot<V> {
impl<V: Clone + PortableHash> Store for Snapshot<V> {
type Error = TrieError;
type Value = V;

// TODO fix possible stack overflow
// I dislike using an explicit mutable stack.
Expand Down Expand Up @@ -166,8 +375,9 @@ struct SnapshotBuilderInner<Db: 'static, V: 'static> {
nodes: RefCell<Vec<NodeHashMaybeNode<'this, V>>>,
}

impl<Db: DatabaseGet<V>, V: Clone> Store<V> for SnapshotBuilder<Db, V> {
impl<Db: DatabaseGet<V>, V: Clone + PortableHash> Store for SnapshotBuilder<Db, V> {
type Error = TrieError;
type Value = V;

#[inline]
fn calc_subtree_hash(
Expand Down
Loading
Loading