Skip to content

Commit

Permalink
MerkleDB Compact Path Bytes (#2010)
Browse files Browse the repository at this point in the history
Signed-off-by: David Boehm <91908103+dboehm-avalabs@users.noreply.github.com>
Co-authored-by: Dan Laine <daniel.laine@avalabs.org>
  • Loading branch information
dboehm-avalabs and Dan Laine authored Sep 29, 2023
1 parent 08044ba commit f79d609
Show file tree
Hide file tree
Showing 32 changed files with 1,955 additions and 1,062 deletions.
170 changes: 84 additions & 86 deletions proto/pb/sync/sync.pb.go

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions proto/sync/sync.proto
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ message RangeProof {
}

message ProofNode {
SerializedPath key = 1;
Path key = 1;
MaybeBytes value_or_hash = 2;
map<uint32, bytes> children = 3;
}
Expand All @@ -149,8 +149,8 @@ message KeyChange {
MaybeBytes value = 2;
}

message SerializedPath {
uint64 nibble_length = 1;
message Path {
uint64 length = 1;
bytes value = 2;
}

Expand Down
102 changes: 49 additions & 53 deletions x/merkledb/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"sync"
Expand All @@ -22,10 +21,10 @@ const (
falseByte = 0
minVarIntLen = 1
minMaybeByteSliceLen = boolLen
minSerializedPathLen = minVarIntLen
minPathLen = minVarIntLen
minByteSliceLen = minVarIntLen
minDBNodeLen = minMaybeByteSliceLen + minVarIntLen
minChildLen = minVarIntLen + minSerializedPathLen + ids.IDLen + boolLen
minChildLen = minVarIntLen + minPathLen + ids.IDLen + boolLen

estimatedKeyLen = 64
estimatedValueLen = 64
Expand All @@ -42,13 +41,13 @@ var (
trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

errTooManyChildren = fmt.Errorf("length of children list is larger than branching factor of %d", NodeBranchFactor)
errChildIndexTooLarge = fmt.Errorf("invalid child index. Must be less than branching factor of %d", NodeBranchFactor)
errLeadingZeroes = errors.New("varint has leading zeroes")
errInvalidBool = errors.New("decoded bool is neither true nor false")
errNonZeroNibblePadding = errors.New("nibbles should be padded with 0s")
errExtraSpace = errors.New("trailing buffer space")
errIntOverflow = errors.New("value overflows int")
errTooManyChildren = errors.New("length of children list is larger than branching factor")
errChildIndexTooLarge = errors.New("invalid child index. Must be less than branching factor")
errLeadingZeroes = errors.New("varint has leading zeroes")
errInvalidBool = errors.New("decoded bool is neither true nor false")
errNonZeroPathPadding = errors.New("path partial byte should be padded with 0s")
errExtraSpace = errors.New("trailing buffer space")
errIntOverflow = errors.New("value overflows int")
)

// encoderDecoder defines the interface needed by merkleDB to marshal
Expand All @@ -60,14 +59,14 @@ type encoderDecoder interface {

type encoder interface {
// Assumes [n] is non-nil.
encodeDBNode(n *dbNode) []byte
encodeDBNode(n *dbNode, factor BranchFactor) []byte
// Assumes [hv] is non-nil.
encodeHashValues(hv *hashValues) []byte
}

type decoder interface {
// Assumes [n] is non-nil.
decodeDBNode(bytes []byte, n *dbNode) error
decodeDBNode(bytes []byte, n *dbNode, factor BranchFactor) error
}

func newCodec() encoderDecoder {
Expand All @@ -88,7 +87,7 @@ type codecImpl struct {
varIntPool sync.Pool
}

func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
func (c *codecImpl) encodeDBNode(n *dbNode, branchFactor BranchFactor) []byte {
var (
numChildren = len(n.children)
// Estimate size of [n] to prevent memory allocations
Expand All @@ -100,11 +99,10 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
c.encodeUint(buf, uint64(numChildren))
// Note we insert children in order of increasing index
// for determinism.
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := n.children[index]; ok {
for index := 0; BranchFactor(index) < branchFactor; index++ {
if entry, ok := n.children[byte(index)]; ok {
c.encodeUint(buf, uint64(index))
path := entry.compressedPath.Serialize()
c.encodeSerializedPath(buf, path)
c.encodePath(buf, entry.compressedPath)
_, _ = buf.Write(entry.id[:])
c.encodeBool(buf, entry.hasValue)
}
Expand All @@ -123,19 +121,19 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
c.encodeUint(buf, uint64(numChildren))

// ensure that the order of entries is consistent
for index := byte(0); index < NodeBranchFactor; index++ {
if entry, ok := hv.Children[index]; ok {
for index := 0; BranchFactor(index) < hv.Key.branchFactor; index++ {
if entry, ok := hv.Children[byte(index)]; ok {
c.encodeUint(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
}
c.encodeMaybeByteSlice(buf, hv.Value)
c.encodeSerializedPath(buf, hv.Key)
c.encodePath(buf, hv.Key)

return buf.Bytes()
}

func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
func (c *codecImpl) decodeDBNode(b []byte, n *dbNode, branchFactor BranchFactor) error {
if minDBNodeLen > len(b) {
return io.ErrUnexpectedEOF
}
Expand All @@ -152,25 +150,25 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
switch {
case err != nil:
return err
case numChildren > NodeBranchFactor:
case numChildren > uint64(branchFactor):
return errTooManyChildren
case numChildren > uint64(src.Len()/minChildLen):
return io.ErrUnexpectedEOF
}

n.children = make(map[byte]child, NodeBranchFactor)
n.children = make(map[byte]child, branchFactor)
var previousChild uint64
for i := uint64(0); i < numChildren; i++ {
index, err := c.decodeUint(src)
if err != nil {
return err
}
if index >= NodeBranchFactor || (i != 0 && index <= previousChild) {
if index >= uint64(branchFactor) || (i != 0 && index <= previousChild) {
return errChildIndexTooLarge
}
previousChild = index

compressedPath, err := c.decodeSerializedPath(src)
compressedPath, err := c.decodePath(src, branchFactor)
if err != nil {
return err
}
Expand All @@ -183,7 +181,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
return err
}
n.children[byte(index)] = child{
compressedPath: compressedPath.deserialize(),
compressedPath: compressedPath,
id: childID,
hasValue: hasValue,
}
Expand Down Expand Up @@ -328,47 +326,45 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
return id, err
}

func (c *codecImpl) encodeSerializedPath(dst *bytes.Buffer, s SerializedPath) {
c.encodeUint(dst, uint64(s.NibbleLength))
_, _ = dst.Write(s.Value)
func (c *codecImpl) encodePath(dst *bytes.Buffer, p Path) {
c.encodeUint(dst, uint64(p.tokensLength))
_, _ = dst.Write(p.Bytes())
}

func (c *codecImpl) decodeSerializedPath(src *bytes.Reader) (SerializedPath, error) {
if minSerializedPathLen > src.Len() {
return SerializedPath{}, io.ErrUnexpectedEOF
func (c *codecImpl) decodePath(src *bytes.Reader, branchFactor BranchFactor) (Path, error) {
if minPathLen > src.Len() {
return Path{}, io.ErrUnexpectedEOF
}

nibbleLength, err := c.decodeUint(src)
length, err := c.decodeUint(src)
if err != nil {
return SerializedPath{}, err
}
if nibbleLength > math.MaxInt {
return SerializedPath{}, errIntOverflow
}

result := SerializedPath{
NibbleLength: int(nibbleLength),
return Path{}, err
}
pathBytesLen := result.NibbleLength >> 1
hasOddLen := result.hasOddLength()
if hasOddLen {
pathBytesLen++
if length > math.MaxInt {
return Path{}, errIntOverflow
}
result := emptyPath(branchFactor)
result.tokensLength = int(length)
pathBytesLen := result.bytesNeeded(result.tokensLength)
if pathBytesLen > src.Len() {
return SerializedPath{}, io.ErrUnexpectedEOF
return Path{}, io.ErrUnexpectedEOF
}
result.Value = make([]byte, pathBytesLen)
if _, err := io.ReadFull(src, result.Value); err != nil {
buffer := make([]byte, pathBytesLen)
if _, err := io.ReadFull(src, buffer); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return SerializedPath{}, err
return Path{}, err
}
if hasOddLen {
paddedNibble := result.Value[pathBytesLen-1] & 0x0F
if paddedNibble != 0 {
return SerializedPath{}, errNonZeroNibblePadding
if result.hasPartialByte() {
// Confirm that the padding bits in the partial byte are 0.
// We want to only look at the bits to the right of the last token, which is at index length-1.
// Generate a mask with (8-bitsToShift) 0s followed by bitsToShift 1s.
paddingMask := byte(0xFF >> (8 - result.bitsToShift(result.tokensLength-1)))
if buffer[pathBytesLen-1]&paddingMask != 0 {
return Path{}, errNonZeroPathPadding
}
}
result.value = string(buffer)
return result, nil
}
Loading

0 comments on commit f79d609

Please sign in to comment.