Skip to content

Commit

Permalink
Merge pull request #29 from earthstar-project/efficient_path_decoding
Browse files Browse the repository at this point in the history
Efficient path decoding
  • Loading branch information
sgwilym authored Jul 24, 2024
2 parents b051ae0 + 629c695 commit 14dbe1a
Show file tree
Hide file tree
Showing 22 changed files with 508 additions and 68 deletions.
1 change: 1 addition & 0 deletions data-model/src/encoding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pub mod error;
pub mod max_power;
pub mod parameters;
pub mod relativity;
pub(crate) mod shared_buffers;
pub mod unsigned_int;
114 changes: 85 additions & 29 deletions data-model/src/encoding/relativity.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use core::future::Future;
use core::mem::{size_of, MaybeUninit};
use ufotofu::local_nb::{BulkConsumer, BulkProducer};

use crate::{
Expand All @@ -19,6 +20,9 @@ use crate::{
path::Path,
};


use super::shared_buffers::ScratchSpacePathDecoding;

/// A type that can be used to encode `T` to a bytestring *encoded relative to `R`*.
/// This can be used to create more compact encodings from which `T` can be derived by anyone with `R`.
pub trait RelativeEncodable<R> {
Expand Down Expand Up @@ -62,27 +66,21 @@ impl<const MCL: usize, const MCC: usize, const MPL: usize> RelativeEncodable<Pat
Consumer: BulkConsumer<Item = u8>,
{
let lcp = self.longest_common_prefix(reference);
encode_max_power(lcp.get_component_count(), MCC, consumer).await?;

if lcp.get_component_count() > 0 {
let suffix_components = self.suffix_components(lcp.get_component_count());

// TODO: A more performant version of this.
let lcp_component_count = lcp.get_component_count();
encode_max_power(lcp_component_count, MCC, consumer).await?;

let mut suffix = Path::<MCL, MCC, MPL>::new_empty();
let suffix_component_count = self.get_component_count() - lcp_component_count;
encode_max_power(suffix_component_count, MCC, consumer).await?;

for component in suffix_components {
// We can unwrap here because this suffix is a subset of a valid path.
suffix = suffix.append(component).unwrap();
}

suffix.encode(consumer).await?;
for component in self.suffix_components(lcp_component_count) {
encode_max_power(component.len(), MCL, consumer).await?;

return Ok(());
consumer
.bulk_consume_full_slice(component.as_ref())
.await
.map_err(|f| f.reason)?;
}

self.encode(consumer).await?;

Ok(())
}
}
Expand All @@ -101,33 +99,91 @@ impl<const MCL: usize, const MCC: usize, const MPL: usize> RelativeDecodable<Pat
Producer: BulkProducer<Item = u8>,
Self: Sized,
{
let lcp = decode_max_power(MCC, producer).await?;
let lcp_component_count: usize = decode_max_power(MCC, producer).await?.try_into()?;

if lcp > reference.get_component_count() as u64 {
return Err(DecodeError::InvalidInput);
if lcp_component_count == 0 {
let decoded = Path::<MCL, MCC, MPL>::decode(producer).await?;

// === Necessary to produce canonic encodings. ===
if lcp_component_count
!= decoded
.longest_common_prefix(reference)
.get_component_count()
{
return Err(DecodeError::InvalidInput);
}
// ===============================================

return Ok(decoded);
}

let prefix = reference
.create_prefix(lcp as usize)
.create_prefix(lcp_component_count as usize)
.ok_or(DecodeError::InvalidInput)?;
let suffix = Path::<MCL, MCC, MPL>::decode(producer).await?;

let mut new = prefix;
let mut buf = ScratchSpacePathDecoding::<MCC, MPL>::new();

// Copy the accumulated component lengths of the prefix into the scratch buffer.
let raw_prefix_acc_component_lengths =
&prefix.raw_buf()[size_of::<usize>()..size_of::<usize>() * (lcp_component_count + 1)];
unsafe {
// Safe because len is less than size_of::<usize>() times the MCC, because `prefix` respects the MCC.
buf.set_many_component_accumulated_lengths_from_ne(raw_prefix_acc_component_lengths);
}

for component in suffix.components() {
match new.append(component) {
Ok(appended) => new = appended,
Err(_) => return Err(DecodeError::InvalidInput),
// Copy the raw path data of the prefix into the scratch buffer.
unsafe {
// safe because we just copied the accumulated component lengths for the first `lcp_component_count` components.
MaybeUninit::copy_from_slice(
buf.path_data_until_as_mut(lcp_component_count),
&reference.raw_buf()[size_of::<usize>() * (reference.get_component_count() + 1)
..size_of::<usize>() * (reference.get_component_count() + 1)
+ prefix.get_path_length()],
);
}

let remaining_component_count: usize = decode_max_power(MCC, producer).await?.try_into()?;
let total_component_count = lcp_component_count + remaining_component_count;
if total_component_count > MCC {
return Err(DecodeError::InvalidInput);
}

let mut accumulated_component_length: usize = prefix.get_path_length(); // Always holds the acc length of all components we copied so far.
for i in lcp_component_count..total_component_count {
let component_len: usize = decode_max_power(MCL, producer).await?.try_into()?;
if component_len > MCL {
return Err(DecodeError::InvalidInput);
}

accumulated_component_length += component_len;
if accumulated_component_length > MPL {
return Err(DecodeError::InvalidInput);
}

buf.set_component_accumulated_length(accumulated_component_length, i);

// Decode the component itself into the scratch buffer.
producer
.bulk_overwrite_full_slice_uninit(unsafe {
// Safe because we called set_component_accumulated_length for all j <= i
buf.path_data_as_mut(i)
})
.await?;
}

let actual_lcp = reference.longest_common_prefix(&new);
let decoded = unsafe { buf.to_path(total_component_count) };

if actual_lcp.get_component_count() != lcp as usize {
// === Necessary to produce canonic encodings. ===
if lcp_component_count
!= decoded
.longest_common_prefix(reference)
.get_component_count()
{
return Err(DecodeError::InvalidInput);
}
// ===============================================

Ok(new)
Ok(decoded)
}
}

Expand Down
132 changes: 132 additions & 0 deletions data-model/src/encoding/shared_buffers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
//! Shared buffers to be reused across many decoding operations.
use core::mem::MaybeUninit;

use bytes::BytesMut;

use crate::path::Path;

/// A memory region to use for decoding paths. Reused between many decodings.
///
/// Both arrays start out uninitialised; the accessor methods on this type state which
/// entries must have been written before they may be read.
#[derive(Debug)]
pub(crate) struct ScratchSpacePathDecoding<const MCC: usize, const MPL: usize> {
// The usize at index `i` holds the total length in bytes of components `0..=i`,
// i.e. the offset into `path_data` at which component `i` ends (and component `i + 1` begins).
component_accumulated_lengths: [MaybeUninit<usize>; MCC],
// The bytes of all components, concatenated without separators; component boundaries
// are given by `component_accumulated_lengths`.
path_data: [MaybeUninit<u8>; MPL],
}

impl<const MCC: usize, const MPL: usize> ScratchSpacePathDecoding<MCC, MPL> {
    /// Create a scratch space in which no accumulated length and no path data is initialised yet.
    pub fn new() -> Self {
        ScratchSpacePathDecoding {
            component_accumulated_lengths: MaybeUninit::uninit_array(),
            path_data: MaybeUninit::uninit_array(),
        }
    }

    /// Store `component_accumulated_length` as the accumulated length at index `i`,
    /// i.e. the total length of components `0..=i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= MCC`.
    pub fn set_component_accumulated_length(
        &mut self,
        component_accumulated_length: usize,
        i: usize,
    ) {
        self.component_accumulated_lengths[i].write(component_accumulated_length);
    }

    /// Initialise the leading accumulated lengths by copying `lengths` byte-for-byte over the
    /// start of the accumulated-length array. The bytes are later read back as native-endian
    /// `usize`s, so `lengths` is expected to be a sequence of native-endian `usize`s.
    ///
    /// # Panics
    ///
    /// Panics if `lengths.len() / size_of::<usize>()` exceeds `MCC`.
    ///
    /// # Safety
    ///
    /// UB if length of slice is greater than `size_of::<usize>() * MCC`.
    pub unsafe fn set_many_component_accumulated_lengths_from_ne(&mut self, lengths: &[u8]) {
        // Catch contract violations in debug builds before any raw-pointer work happens.
        debug_assert!(lengths.len() <= core::mem::size_of::<usize>() * MCC);

        // SAFETY (caller contract): `lengths.len()` bytes fit into the accumulated-length array,
        // and any byte pattern is a valid `MaybeUninit<u8>`.
        let slice: &mut [MaybeUninit<u8>] = core::slice::from_raw_parts_mut(
            self.component_accumulated_lengths[..lengths.len() / core::mem::size_of::<usize>()]
                .as_mut_ptr() as *mut MaybeUninit<u8>,
            lengths.len(),
        );
        MaybeUninit::copy_from_slice(slice, lengths);
    }

    /// Read the accumulated length stored at index `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= MCC`.
    ///
    /// # Safety
    ///
    /// Memory must have been initialised with a prior call to set_component_accumulated_length for the same `i`
    /// (or by `set_many_component_accumulated_lengths_from_ne` covering index `i`).
    unsafe fn get_component_accumulated_length(&self, i: usize) -> usize {
        self.component_accumulated_lengths[i].assume_init()
    }

    /// Total length in bytes of the first `i` components; `0` when `i == 0`.
    ///
    /// Centralises the `i - 1` lookup so that no caller can underflow on `i == 0`.
    ///
    /// # Safety
    ///
    /// If `i > 0`, the accumulated length at index `i - 1` must have been set.
    unsafe fn total_length_of_first(&self, i: usize) -> usize {
        match i.checked_sub(1) {
            Some(last) => self.get_component_accumulated_length(last),
            None => 0,
        }
    }

    /// Return a slice of the accumulated component lengths up to but excluding the `i`-th component, encoded as native-endian u8s.
    ///
    /// # Safety
    ///
    /// Memory must have been initialised with prior calls to set_component_accumulated_length for all `j < i`.
    pub unsafe fn get_accumumulated_component_lengths(&self, i: usize) -> &[u8] {
        // Reinterpret the first `i` initialised `usize`s as their underlying bytes.
        core::slice::from_raw_parts(
            MaybeUninit::slice_assume_init_ref(&self.component_accumulated_lengths[..i]).as_ptr()
                as *const u8,
            i * core::mem::size_of::<usize>(),
        )
    }

    /// Return a mutable slice of the i-th path_data, i.e. the (possibly still uninitialised)
    /// bytes reserved for component `i`.
    ///
    /// # Safety
    ///
    /// Accumulated component lengths for `i` and `i - 1` must have been set (only for `i` if `i == 0`).
    pub unsafe fn path_data_as_mut(&mut self, i: usize) -> &mut [MaybeUninit<u8>] {
        let start = self.total_length_of_first(i);
        let end = self.get_component_accumulated_length(i);
        &mut self.path_data[start..end]
    }

    /// Return a mutable slice of the path_data up to but excluding the i-th component.
    /// For `i == 0` this is the empty slice.
    ///
    /// # Safety
    ///
    /// Accumulated component lengths for `i - 1` must have been set (unless `i == 0`).
    pub unsafe fn path_data_until_as_mut(&mut self, i: usize) -> &mut [MaybeUninit<u8>] {
        // Going through the helper avoids computing `i - 1` when `i == 0`.
        let end = self.total_length_of_first(i);
        &mut self.path_data[..end]
    }

    /// Get the path data of the first `i` components.
    ///
    /// # Safety
    ///
    /// Memory must have been initialised with a prior call to set_component_accumulated_length for `i - 1` (ignored if `i == 0`).
    /// Also, the path data must have been initialised via a reference obtained from `self.path_data_as_mut()`
    /// or `self.path_data_until_as_mut()`.
    unsafe fn get_path_data(&self, i: usize) -> &[u8] {
        let end = self.total_length_of_first(i);
        MaybeUninit::slice_assume_init_ref(&self.path_data[..end])
    }

    /// Copy the data from this struct into a new Path of `i` components.
    ///
    /// The resulting buffer consists of `i` as a native-endian `usize`, followed by the first
    /// `i` accumulated component lengths, followed by the bytes of the first `i` components.
    ///
    /// # Safety
    ///
    /// The first `i` accumulated component lengths must have been set, and all corresponding path data must be initialised. MCL, MCC, and MPL are trusted blindly and must be adhered to by the data in the scratch buffer.
    pub unsafe fn to_path<const MCL: usize>(&self, i: usize) -> Path<MCL, MCC, MPL> {
        if i == 0 {
            return Path::new_empty();
        }

        // `i > 0` from here on, so index `i - 1` exists and holds the total path length.
        let total_length = self.get_component_accumulated_length(i - 1);
        let mut buf =
            BytesMut::with_capacity(core::mem::size_of::<usize>() * (i + 1) + total_length);

        buf.extend_from_slice(&i.to_ne_bytes());
        buf.extend_from_slice(self.get_accumumulated_component_lengths(i));
        buf.extend_from_slice(self.get_path_data(i));

        Path::from_buffer_and_component_count(buf.freeze(), i)
    }
}
9 changes: 8 additions & 1 deletion data-model/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
#![feature(new_uninit, async_fn_traits, debug_closure_helpers)]
#![feature(
new_uninit,
async_fn_traits,
debug_closure_helpers,
maybe_uninit_uninit_array,
maybe_uninit_write_slice,
maybe_uninit_slice
)]

pub mod encoding;
pub mod entry;
Expand Down
44 changes: 30 additions & 14 deletions data-model/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::encoding::{
error::DecodeError,
max_power::{decode_max_power, encode_max_power},
parameters::{Decodable, Encodable},
shared_buffers::ScratchSpacePathDecoding,
};

// This struct is tested in `fuzz/path.rs`, `fuzz/path2.rs`, and `fuzz/path3.rs` by comparing against a non-optimised reference implementation.
Expand Down Expand Up @@ -660,12 +661,16 @@ impl<const MCL: usize, const MCC: usize, const MPL: usize> Path<MCL, MCC, MPL> {
None
}

fn from_buffer_and_component_count(buf: Bytes, component_count: usize) -> Self {
pub(crate) fn from_buffer_and_component_count(buf: Bytes, component_count: usize) -> Self {
Path {
data: HeapEncoding(buf),
component_count,
}
}

/// Returns the bytes of the buffer wrapped by this path's `data` field, without copying.
///
/// NOTE(review): callers in this crate index into the result assuming a leading
/// native-endian `usize`, then the accumulated component lengths, then the concatenated
/// component bytes — confirm against the `HeapEncoding` definition before relying on it.
pub(crate) fn raw_buf(&self) -> &[u8] {
self.data.0.as_ref()
}
}

impl<const MCL: usize, const MCC: usize, const MPL: usize> PartialEq for Path<MCL, MCC, MPL> {
Expand Down Expand Up @@ -808,26 +813,37 @@ impl<const MCL: usize, const MCC: usize, const MPL: usize> Decodable for Path<MC
where
P: BulkProducer<Item = u8>,
{
let component_count = decode_max_power(MCC, producer).await?;
let component_count: usize = decode_max_power(MCC, producer).await?.try_into()?;
if component_count > MCC {
return Err(DecodeError::InvalidInput);
}

let mut path = Self::new_empty();
let mut buf = ScratchSpacePathDecoding::<MCC, MPL>::new();

for _ in 0..component_count {
let component_len = decode_max_power(MCL, producer).await?;
let mut accumulated_component_length: usize = 0; // Always holds the acc length of all components we copied so far.
for i in 0..component_count {
let component_len: usize = decode_max_power(MCL, producer).await?.try_into()?;
if component_len > MCL {
return Err(DecodeError::InvalidInput);
}

accumulated_component_length += component_len;
if accumulated_component_length > MPL {
return Err(DecodeError::InvalidInput);
}

let mut component_box = Box::new_uninit_slice(usize::try_from(component_len)?);
buf.set_component_accumulated_length(accumulated_component_length, i);

let slice = producer
.bulk_overwrite_full_slice_uninit(component_box.as_mut())
// Decode the component itself into the scratch buffer.
producer
.bulk_overwrite_full_slice_uninit(unsafe {
// Safe because we called set_component_accumulated_length for all j <= i
buf.path_data_as_mut(i)
})
.await?;

let path_component = Component::new(slice).ok_or(DecodeError::InvalidInput)?;
path = path
.append(path_component)
.map_err(|_| DecodeError::InvalidInput)?;
}

Ok(path)
Ok(unsafe { buf.to_path(component_count) })
}
}

Expand Down
Loading

0 comments on commit 14dbe1a

Please sign in to comment.