Skip to content

Commit

Permalink
Add CachingShardStore for batching writes to a backend ShardStore
Browse files Browse the repository at this point in the history
  • Loading branch information
str4d committed Jul 14, 2023
1 parent 84d652d commit ad177e7
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 0 deletions.
216 changes: 216 additions & 0 deletions shardtree/src/caching.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
//! Implementation of an in-memory shard store with persistence.

use std::convert::Infallible;
use std::fmt;

use incrementalmerkletree::Address;

use crate::memory::MemoryShardStore;
use crate::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore};

#[derive(Debug)]
enum Action<C> {
Truncate(Address),
RemoveCheckpoint(C),
TruncateCheckpoints(C),
}

/// An implementation of [`ShardStore`] that caches all state in memory.
///
/// Cache state is flushed to the backend via [`Self::flush`]. Dropping will instead drop
/// the cached state and not make any changes to the backend.
#[derive(Debug)]
pub struct CachingShardStore<S>
where
S: ShardStore,
S::H: Clone,
S::CheckpointId: Clone + Ord,
S::Error: fmt::Display,
{
backend: S,
cache: MemoryShardStore<S::H, S::CheckpointId>,
deferred_actions: Vec<Action<S::CheckpointId>>,
}

impl<S> CachingShardStore<S>
where
S: ShardStore,
S::H: Clone,
S::CheckpointId: Clone + Ord,
S::Error: fmt::Display,
{
/// Loads a `CachingShardStore` from the given backend.
pub fn load(mut backend: S) -> Result<Self, S::Error> {
let mut cache = MemoryShardStore::empty();

for shard_root in backend.get_shard_roots()? {
let _ = cache.put_shard(backend.get_shard(shard_root)?.expect("known address"));
}
let _ = cache.put_cap(backend.get_cap()?);

backend.with_checkpoints(backend.checkpoint_count()?, |checkpoint_id, checkpoint| {
Ok(cache
.add_checkpoint(checkpoint_id.clone(), checkpoint.clone())
.unwrap())
})?;

Ok(Self {
backend,
cache,
deferred_actions: vec![],
})
}

/// Flushes the current cache state to the backend and returns it.
pub fn flush(mut self) -> Result<S, S::Error> {
for action in &self.deferred_actions {
match action {
Action::Truncate(from) => self.backend.truncate(*from),
Action::RemoveCheckpoint(checkpoint_id) => {
self.backend.remove_checkpoint(checkpoint_id)
}
Action::TruncateCheckpoints(checkpoint_id) => {
self.backend.truncate_checkpoints(checkpoint_id)
}
}?;
}
self.deferred_actions.clear();

for shard_root in self.cache.get_shard_roots().unwrap() {
self.backend.put_shard(
self.cache
.get_shard(shard_root)
.unwrap()
.expect("known address"),
)?;
}
self.backend.put_cap(self.cache.get_cap().unwrap())?;

let mut checkpoints = Vec::with_capacity(self.cache.checkpoint_count().unwrap());
self.cache
.with_checkpoints(
self.cache.checkpoint_count().unwrap(),
|checkpoint_id, checkpoint| {
checkpoints.push((checkpoint_id.clone(), checkpoint.clone()));
Ok(())
},
)
.unwrap();
for (checkpoint_id, checkpoint) in checkpoints {
self.backend.add_checkpoint(checkpoint_id, checkpoint)?;
}

Ok(self.backend)
}
}

impl<S> ShardStore for CachingShardStore<S>
where
S: ShardStore,
S::H: Clone,
S::CheckpointId: Clone + Ord,
S::Error: fmt::Display,
{
type H = S::H;
type CheckpointId = S::CheckpointId;
type Error = Infallible;

fn get_shard(
&self,
shard_root: Address,
) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
self.cache.get_shard(shard_root)
}

fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
self.cache.last_shard()
}

fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
self.cache.put_shard(subtree)
}

fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
self.cache.get_shard_roots()
}

fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
self.deferred_actions.push(Action::Truncate(from));
self.cache.truncate(from)
}

fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
self.cache.get_cap()
}

fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
self.cache.put_cap(cap)
}

fn add_checkpoint(
&mut self,
checkpoint_id: Self::CheckpointId,
checkpoint: Checkpoint,
) -> Result<(), Self::Error> {
self.cache.add_checkpoint(checkpoint_id, checkpoint)
}

fn checkpoint_count(&self) -> Result<usize, Self::Error> {
self.cache.checkpoint_count()
}

fn get_checkpoint(
&self,
checkpoint_id: &Self::CheckpointId,
) -> Result<Option<Checkpoint>, Self::Error> {
self.cache.get_checkpoint(checkpoint_id)
}

fn get_checkpoint_at_depth(
&self,
checkpoint_depth: usize,
) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
self.cache.get_checkpoint_at_depth(checkpoint_depth)
}

fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
self.cache.min_checkpoint_id()
}

fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
self.cache.max_checkpoint_id()
}

fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
{
self.cache.with_checkpoints(limit, callback)
}

fn update_checkpoint_with<F>(
&mut self,
checkpoint_id: &Self::CheckpointId,
update: F,
) -> Result<bool, Self::Error>
where
F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
{
self.cache.update_checkpoint_with(checkpoint_id, update)
}

fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
self.deferred_actions
.push(Action::RemoveCheckpoint(checkpoint_id.clone()));
self.cache.remove_checkpoint(checkpoint_id)
}

fn truncate_checkpoints(
&mut self,
checkpoint_id: &Self::CheckpointId,
) -> Result<(), Self::Error> {
self.deferred_actions
.push(Action::TruncateCheckpoints(checkpoint_id.clone()));
self.cache.truncate_checkpoints(checkpoint_id)
}
}
6 changes: 6 additions & 0 deletions shardtree/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub use self::prunable::{
IncompleteAt, InsertionError, LocatedPrunableTree, PrunableTree, QueryError, RetentionFlags,
};

pub mod caching;
pub mod memory;

#[cfg(any(bench, test, feature = "test-dependencies"))]
Expand Down Expand Up @@ -364,6 +365,11 @@ impl<
}
}

/// Consumes this tree and returns its underlying `ShardStore`.
pub fn into_store(self) -> S {
self.store
}

/// Returns the root address of the tree.
pub fn root_addr() -> Address {
Address::from_parts(Level::from(DEPTH), 0)
Expand Down

0 comments on commit ad177e7

Please sign in to comment.