diff --git a/docs/radix-memory.md b/docs/radix-memory.md
new file mode 100644
index 0000000..9252866
--- /dev/null
+++ b/docs/radix-memory.md
@@ -0,0 +1,115 @@
+# Benchmarking Radix Trie
+
+Asterisc moved away from its hashmap-based memory structure to a radix-trie-based memory structure.
+
+This was done in order to:
+
+1. Improve client diversity by differentiating the implementation from cannon
+2. Improve runtime performance
+
+In a radix trie, the branching factor and the depth of the trie are critical. Depending on the sparsity of the dataset, we must tune the radix trie to best suit the program it runs.
+
+- Larger radix branching factors lead to fewer levels and larger node sizes. Fewer levels mean less pointer indirection and shallower traversal, but the larger nodes increase the memory footprint and require more computation to generate the merkle root of a single node.
+- Smaller radix branching factors lead to more levels and smaller node sizes, which have the opposite effect.
+
+We used two methods to benchmark the change to the radix trie: the benchmark unit test and the full op-program run described below.
+
+Multiple variants of the radix trie were tested, each with different branching factors.
+
+Here is the list of Asterisc variants implemented with different configs:
+
+| Variant | Radix type |
+| --- | --- |
+| Asterisc v1.0.0 | non-radix |
+| Radix 1 | [4,4,4,4,4,4,4,8,8,8] - 10 levels |
+| Radix 2 | [8,8,8,8,8,8,4] - 7 levels |
+| Radix 3 | [4,4,4,4,4,4,4,4,4,4,4,4,4] - 13 levels |
+| Radix 4 | [8,8,8,8,8,4,4,4] - 8 levels |
+| Radix 5 | [16,16,6,6,4,4] - 6 levels |
+| Radix 6 | [16,16,6,6,4,2,2] - 7 levels |
+| Radix 7 | [10,10,10,10,12] - 5 levels |
+
+## Benchmark Unit Test
+
+A new benchmark suite was added, which measures the latency of the following operations:
+
+- Memory read / write to random addresses
+- Memory read / write to contiguous addresses
+- Memory write to sparse memory addresses
+- Memory write to dense memory addresses
+- Merkle proof generation
+- Merkle root calculation
+
+For the above cases, each Asterisc implementation had the following results:
+
+|  | Asterisc v1.0.0 | Radix 1 | Radix 2 | Radix 3 | Radix 4 | Radix 5 | Radix 6 | Radix 7 |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| RandomReadWrite | 17.9n | 15.96 | 15.62 | 16.58 | 15.82 | 16.2 | 18.98 | 15.89 |
+| SequentialReadWrite | 5.68n | 4.386 | 4.214 | 4.177 | 4.242 | 4.335 | 4.573 | 4.238 |
+| SparseMemoryUsage | 4.964µ | 5.845 | 6.317 | 5.187 | 5.93 | 4.954 | 8.265 | 24.567 |
+| DenseMemoryUsage | 11.73n | 9.094 | 9.649 | 10.11 | 10.12 | 10.4 | 10.12 | 10.12 |
+| MerkleProof | 1.97µ | 1.441 | 1.464 | 1.611 | 1.737 | 1.604 | 1.737 | 1.98 |
+| MerkleRoot | 6.129n | 4.536 | 4.52 | 4.509 | 4.648 | 4.623 | 4.746 | 4.928 |
+
+The statistics above are the Go benchmark `sec/op` results; the unit suffix shown in the Asterisc v1.0.0 column (n for nanoseconds, µ for microseconds) applies across each row. Most of the results show that the radix-based implementation is faster than the previous hashmap-based memory, except for a few outliers.
+
+Note that this does not account for memory usage such as `B/op` and `allocs/op`. As explained above, each initialization of a radix-trie node allocates more memory than a hashmap entry would, usually leading to a larger memory footprint.
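+To make the branching-factor tradeoff described above concrete, the sketch below shows how a 64-bit address would decompose into per-level radix indices for two of the configurations in the variant table. It is illustrative only and not part of this PR's code; `splitPageIndex` and `pageAddrBits` are made-up names, though the 12-bit in-page offset and the 52-bit page index match the implementation.
+
+```go
+package main
+
+import "fmt"
+
+// pageAddrBits is the per-page offset width (4 KiB pages).
+const pageAddrBits = 12
+
+// splitPageIndex splits the 52-bit page index of addr into per-level indices,
+// given a branching-factor configuration expressed in bits per level.
+func splitPageIndex(addr uint64, levelBits []uint64) []uint64 {
+	pageIndex := addr >> pageAddrBits // drop the 12-bit in-page offset
+	indices := make([]uint64, len(levelBits))
+	remaining := uint64(52)
+	for i, width := range levelBits {
+		remaining -= width
+		indices[i] = (pageIndex >> remaining) & ((1 << width) - 1)
+	}
+	return indices
+}
+
+func main() {
+	addr := uint64(0x10_00_00_00_00_00_10_00)
+	// "Radix 5": 6 levels, with 2^16-wide nodes at the top -> short traversal, big nodes.
+	fmt.Println(splitPageIndex(addr, []uint64{16, 16, 6, 6, 4, 4}))
+	// "Radix 3": 13 levels of 2^4-wide nodes -> deep traversal, small nodes.
+	fmt.Println(splitPageIndex(addr, []uint64{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}))
+}
+```
+
+With the larger top-level factors, the same address is resolved in 6 child lookups instead of 13, at the cost of much larger node arrays at the top of the trie.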
+## Full op-program run
+
+For a more realistic measure of Asterisc's performance, we need to run it against real chain data by running it as the VM client of op-program.
+
+Tests were done on Asterisc running with Kona, for op-sepolia at [block#17484899](https://sepolia-optimism.etherscan.io/block/17484899)
+
+|  | Average | Min | Max | % from v1.0.0 |
+| --- | --- | --- | --- | --- |
+| Asterisc v1.0.0 | 112.759 | 109.345 | 116.045 | 0.00% |
+| Radix 1 | 110.349 | 109.418 | 112.149 | -2.14% |
+| Radix 2 | 109.9 | 107.526 | 111.398 | -2.54% |
+| Radix 3 | 110.589 | 107.814 | 113.544 | -1.92% |
+| Radix 4 | 106.902 | 103.71 | 110.453 | -5.19% |
+| Radix 5 | 106.605 | 104.469 | 109.754 | -5.46% |
+| Radix 6 | 109.137 | 106.764 | 111.819 | -3.21% |
+| Radix 7 | 111.163 | 110.392 | 111.634 | -1.42% |
+| Radix 4 w/ pgo | 98.742 | 97.055 | 101.035 | -12.43% |
+
+As shown above, Radix 4 and Radix 5 had the best results compared to the original Asterisc implementation, with more than 5% improvement in op-program run duration.
+
+After applying [pgo (profile-guided optimization)](https://go.dev/doc/pgo) on Radix 4, we can observe over 12% improvement in speed.
+
+## Visualizing address allocation pattern
+
+In this radix-trie implementation, only the memory addresses that are actually allocated are initialized as radix-trie nodes. Therefore, we can look at the overall memory allocation pattern to see how we can optimize the radix branching factors.
+
+| Radix level (4 bits each) | Allocations |
+| --- | --- |
+| 1 | 2 |
+| 2 | 2 |
+| 3 | 2 |
+| 4 | 2 |
+| 5 | 2 |
+| 6 | 2 |
+| 7 | 2 |
+| 8 | 2 |
+| 9 | 2 |
+| 10 | 3 |
+| 11 | 33 |
+| 12 | 502 |
+| 13 | 7982 |
+
+The table above shows the allocation count at each level during a full op-program run, where the full address space (52 bits of page index) is split into 13 levels (4 bits each).
+
+We can observe that memory allocation is very sparse in the upper part of the address space, while it is heavily dense in the lower part.
+
+With only a couple of allocations in the upper 36 bits of the address space, we can generalize that most of an op-program run is confined to the lower memory address regions.
+
+## Conclusion
+
+Based on the above observations, and our goal of improving runtime performance, we decided on using `radix 5 (16, 16, 6, 6, 4, 4)`.
+
+Usually, a sparse region would use a smaller branching factor to optimize for memory. However, since our goal is faster performance, we use large branching factors at the sparse upper levels to reduce trie traversal depth:
+
+- use larger branching factors at the upper address levels to reduce the trie traversal depth
+- use smaller branching factors at the lower address levels to reduce the computation needed to merkleize each node.
+
+In addition, we can apply pgo as mentioned above. To apply pgo to Asterisc builds, we can run Asterisc with CPU profiling (pprof) enabled, and ship Asterisc with the resulting `default.pgo` in the build path. This way, whenever the user builds Asterisc, pgo will be enabled by default, leading to an additional 5+% improvement in speed.
\ No newline at end of file
diff --git a/rvgo/fast/memory.go b/rvgo/fast/memory.go
index 50ac86a..3e41ae1 100644
--- a/rvgo/fast/memory.go
+++ b/rvgo/fast/memory.go
@@ -5,7 +5,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"math/bits"
 	"sort"
 
 	"github.com/ethereum/go-ethereum/crypto"
@@ -39,14 +38,12 @@ var zeroHashes = func() [256][32]byte {
 }()
 
 type Memory struct {
-	// generalized index -> merkle root or nil if invalidated
-	nodes map[uint64]*[32]byte
-
-	// pageIndex -> cached page
-	pages map[uint64]*CachedPage
+	radix         *L1
+	branchFactors [6]uint64
 
 	// Note: since we don't de-alloc pages, we don't do ref-counting.
// Once a page exists, it doesn't leave memory + pages map[uint64]*CachedPage // two caches: we often read instructions from one page, and do memory things with another page. // this prevents map lookups each instruction @@ -55,10 +52,12 @@ type Memory struct { } func NewMemory() *Memory { + node := &L1{} return &Memory{ - nodes: make(map[uint64]*[32]byte), - pages: make(map[uint64]*CachedPage), - lastPageKeys: [2]uint64{^uint64(0), ^uint64(0)}, // default to invalid keys, to not match any pages + radix: node, + pages: make(map[uint64]*CachedPage), + branchFactors: [6]uint64{16, 16, 6, 6, 4, 4}, + lastPageKeys: [2]uint64{^uint64(0), ^uint64(0)}, // default to invalid keys, to not match any pages } } @@ -75,90 +74,6 @@ func (m *Memory) ForEachPage(fn func(pageIndex uint64, page *Page) error) error return nil } -func (m *Memory) Invalidate(addr uint64) { - // find page, and invalidate addr within it - if p, ok := m.pageLookup(addr >> PageAddrSize); ok { - prevValid := p.Ok[1] - p.Invalidate(addr & PageAddrMask) - if !prevValid { // if the page was already invalid before, then nodes to mem-root will also still be. - return - } - } else { // no page? nothing to invalidate - return - } - - // find the gindex of the first page covering the address - gindex := (uint64(addr) >> PageAddrSize) | (1 << (64 - PageAddrSize)) - - for gindex > 0 { - m.nodes[gindex] = nil - gindex >>= 1 - } -} - -func (m *Memory) MerkleizeSubtree(gindex uint64) [32]byte { - l := uint64(bits.Len64(gindex)) - if l > ProofLen { - panic("gindex too deep") - } - if l > PageKeySize { - depthIntoPage := l - 1 - PageKeySize - pageIndex := (gindex >> depthIntoPage) & PageKeyMask - if p, ok := m.pages[uint64(pageIndex)]; ok { - pageGindex := (1 << depthIntoPage) | (gindex & ((1 << depthIntoPage) - 1)) - return p.MerkleizeSubtree(pageGindex) - } else { - return zeroHashes[64-5+1-l] // page does not exist - } - } - n, ok := m.nodes[gindex] - if !ok { - // if the node doesn't exist, the whole sub-tree is zeroed - return zeroHashes[64-5+1-l] - } - if n != nil { - return *n - } - left := m.MerkleizeSubtree(gindex << 1) - right := m.MerkleizeSubtree((gindex << 1) | 1) - r := HashPair(left, right) - m.nodes[gindex] = &r - return r -} - -func (m *Memory) MerkleProof(addr uint64) (out [ProofLen * 32]byte) { - proof := m.traverseBranch(1, addr, 0) - // encode the proof - for i := 0; i < ProofLen; i++ { - copy(out[i*32:(i+1)*32], proof[i][:]) - } - return out -} - -func (m *Memory) traverseBranch(parent uint64, addr uint64, depth uint8) (proof [][32]byte) { - if depth == ProofLen-1 { - proof = make([][32]byte, 0, ProofLen) - proof = append(proof, m.MerkleizeSubtree(parent)) - return - } - if depth > ProofLen-1 { - panic("traversed too deep") - } - self := parent << 1 - sibling := self | 1 - if addr&(1<<(63-depth)) != 0 { - self, sibling = sibling, self - } - proof = m.traverseBranch(self, addr, depth+1) - siblingNode := m.MerkleizeSubtree(sibling) - proof = append(proof, siblingNode) - return -} - -func (m *Memory) MerkleRoot() [32]byte { - return m.MerkleizeSubtree(1) -} - func (m *Memory) pageLookup(pageIndex uint64) (*CachedPage, bool) { // hit caches if pageIndex == m.lastPageKeys[0] { @@ -257,18 +172,6 @@ func (m *Memory) GetUnaligned(addr uint64, dest []byte) { } } -func (m *Memory) AllocPage(pageIndex uint64) *CachedPage { - p := &CachedPage{Data: new(Page)} - m.pages[pageIndex] = p - // make nodes to root - k := (1 << PageKeySize) | uint64(pageIndex) - for k > 0 { - m.nodes[k] = nil - k >>= 1 - } - return p -} - type pageEntry struct 
{ Index uint64 `json:"index"` Data *Page `json:"data"` @@ -293,7 +196,9 @@ func (m *Memory) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &pages); err != nil { return err } - m.nodes = make(map[uint64]*[32]byte) + + m.branchFactors = [6]uint64{16, 16, 6, 6, 4, 4} + m.radix = &L1{} m.pages = make(map[uint64]*CachedPage) m.lastPageKeys = [2]uint64{^uint64(0), ^uint64(0)} m.lastPage = [2]*CachedPage{nil, nil} diff --git a/rvgo/fast/memory_benchmark_test.go b/rvgo/fast/memory_benchmark_test.go new file mode 100644 index 0000000..dae8870 --- /dev/null +++ b/rvgo/fast/memory_benchmark_test.go @@ -0,0 +1,129 @@ +package fast + +import ( + "math/rand" + "testing" +) + +const ( + smallDataset = 1_000 + mediumDataset = 100_000 + largeDataset = 1_000_000 +) + +func BenchmarkMemoryOperations(b *testing.B) { + benchmarks := []struct { + name string + fn func(b *testing.B, m *Memory) + }{ + {"RandomReadWrite_Small", benchRandomReadWrite(smallDataset)}, + {"RandomReadWrite_Medium", benchRandomReadWrite(mediumDataset)}, + {"RandomReadWrite_Large", benchRandomReadWrite(largeDataset)}, + {"SequentialReadWrite_Small", benchSequentialReadWrite(smallDataset)}, + {"SequentialReadWrite_Large", benchSequentialReadWrite(largeDataset)}, + {"SparseMemoryUsage", benchSparseMemoryUsage}, + {"DenseMemoryUsage", benchDenseMemoryUsage}, + {"SmallFrequentUpdates", benchSmallFrequentUpdates}, + {"MerkleProofGeneration_Small", benchMerkleProofGeneration(smallDataset)}, + {"MerkleProofGeneration_Large", benchMerkleProofGeneration(largeDataset)}, + {"MerkleRootCalculation_Small", benchMerkleRootCalculation(smallDataset)}, + {"MerkleRootCalculation_Large", benchMerkleRootCalculation(largeDataset)}, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + m := NewMemory() + b.ResetTimer() + bm.fn(b, m) + }) + } +} + +func benchRandomReadWrite(size int) func(b *testing.B, m *Memory) { + return func(b *testing.B, m *Memory) { + addresses := make([]uint64, size) + for i := range addresses { + addresses[i] = rand.Uint64() + } + data := make([]byte, 8) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + addr := addresses[i%len(addresses)] + if i%2 == 0 { + m.SetUnaligned(addr, data) + } else { + m.GetUnaligned(addr, data) + } + } + } +} + +func benchSequentialReadWrite(size int) func(b *testing.B, m *Memory) { + return func(b *testing.B, m *Memory) { + data := make([]byte, 8) + b.ResetTimer() + for i := 0; i < b.N; i++ { + addr := uint64(i % size) + if i%2 == 0 { + m.SetUnaligned(addr, data) + } else { + m.GetUnaligned(addr, data) + } + } + } +} + +func benchSparseMemoryUsage(b *testing.B, m *Memory) { + data := make([]byte, 8) + b.ResetTimer() + for i := 0; i < b.N; i++ { + addr := uint64(i) * 10_000_000 // Large gaps between addresses + m.SetUnaligned(addr, data) + } +} + +func benchDenseMemoryUsage(b *testing.B, m *Memory) { + data := make([]byte, 8) + b.ResetTimer() + for i := 0; i < b.N; i++ { + addr := uint64(i) * 8 // Contiguous 8-byte allocations + m.SetUnaligned(addr, data) + } +} + +func benchSmallFrequentUpdates(b *testing.B, m *Memory) { + data := make([]byte, 1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + addr := uint64(rand.Intn(1000000)) // Confined to a smaller range + m.SetUnaligned(addr, data) + } +} + +func benchMerkleProofGeneration(size int) func(b *testing.B, m *Memory) { + return func(b *testing.B, m *Memory) { + // Setup: allocate some memory + for i := 0; i < size; i++ { + m.SetUnaligned(uint64(i)*8, []byte{byte(i)}) + } + b.ResetTimer() + for i := 0; i < b.N; i++ 
{ + addr := uint64(rand.Intn(size) * 8) + _ = m.MerkleProof(addr) + } + } +} + +func benchMerkleRootCalculation(size int) func(b *testing.B, m *Memory) { + return func(b *testing.B, m *Memory) { + // Setup: allocate some memory + for i := 0; i < size; i++ { + m.SetUnaligned(uint64(i)*8, []byte{byte(i)}) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m.MerkleRoot() + } + } +} diff --git a/rvgo/fast/memory_test.go b/rvgo/fast/memory_test.go index 0c4a955..8653a6f 100644 --- a/rvgo/fast/memory_test.go +++ b/rvgo/fast/memory_test.go @@ -2,7 +2,7 @@ package fast import ( "bytes" - "crypto/rand" + cryptorand "crypto/rand" "encoding/binary" "encoding/json" "io" @@ -22,6 +22,7 @@ func TestMemoryMerkleProof(t *testing.T) { require.Equal(t, zeroHashes[i][:], proof[32+i*32:32+i*32+32], "empty siblings") } }) + t.Run("fuller tree", func(t *testing.T) { m := NewMemory() m.SetUnaligned(0x10000, []byte{0xaa, 0xbb, 0xcc, 0xdd}) @@ -31,7 +32,7 @@ func TestMemoryMerkleProof(t *testing.T) { proof := m.MerkleProof(0x80004) require.Equal(t, uint32(42<<24), binary.BigEndian.Uint32(proof[4:8])) node := *(*[32]byte)(proof[:32]) - path := uint32(0x80004) >> 5 + path := 0x80004 >> 5 for i := 32; i < len(proof); i += 32 { sib := *(*[32]byte)(proof[i : i+32]) if path&1 != 0 { @@ -43,6 +44,222 @@ func TestMemoryMerkleProof(t *testing.T) { } require.Equal(t, root, node, "proof must verify") }) + + t.Run("consistency test", func(t *testing.T) { + m := NewMemory() + addr := uint64(0x1234560000000) + m.SetUnaligned(addr, []byte{1}) + proof1 := m.MerkleProof(addr) + proof2 := m.MerkleProof(addr) + require.Equal(t, proof1, proof2, "Proofs for the same address should be consistent") + }) + + t.Run("stress test", func(t *testing.T) { + m := NewMemory() + var addresses []uint64 + for i := uint64(0); i < 10000; i++ { + addr := i * 0x1000000 // Spread out addresses + addresses = append(addresses, addr) + m.SetUnaligned(addr, []byte{byte(i + 1)}) + } + root := m.MerkleRoot() + for _, addr := range addresses { + proof := m.MerkleProof(addr) + verifyProof(t, root, proof, addr) + } + }) + t.Run("boundary addresses", func(t *testing.T) { + m := NewMemory() + addresses := []uint64{ + //0x0000000000000 - 1, // Just before first level + 0x0000000000000, // Start of first level + 0x0400000000000 - 1, // End of first level + 0x0400000000000, // Start of second level + 0x3C00000000000 - 1, // End of fourth level + 0x3C00000000000, // Start of fifth level + 0x3FFFFFFFFFFF, // Maximum address + } + for i, addr := range addresses { + m.SetUnaligned(addr, []byte{byte(i + 1)}) + } + root := m.MerkleRoot() + for _, addr := range addresses { + proof := m.MerkleProof(addr) + verifyProof(t, root, proof, addr) + } + }) + t.Run("multiple levels", func(t *testing.T) { + m := NewMemory() + addresses := []uint64{ + 0x0000000000000, + 0x0400000000000, + 0x0800000000000, + 0x0C00000000000, + 0x1000000000000, + 0x1400000000000, + } + for i, addr := range addresses { + m.SetUnaligned(addr, []byte{byte(i + 1)}) + } + root := m.MerkleRoot() + for _, addr := range addresses { + proof := m.MerkleProof(addr) + verifyProof(t, root, proof, addr) + } + }) + + t.Run("sparse tree", func(t *testing.T) { + m := NewMemory() + addresses := []uint64{ + 0x0000000000000, + 0x0000400000000, + 0x0004000000000, + 0x0040000000000, + 0x0400000000000, + 0x3C00000000000, + } + for i, addr := range addresses { + m.SetUnaligned(addr, []byte{byte(i + 1)}) + } + root := m.MerkleRoot() + for _, addr := range addresses { + proof := m.MerkleProof(addr) + verifyProof(t, root, 
proof, addr) + } + }) + + t.Run("adjacent addresses", func(t *testing.T) { + m := NewMemory() + baseAddr := uint64(0x0400000000000) + for i := uint64(0); i < 16; i++ { + m.SetUnaligned(baseAddr+i, []byte{byte(i + 1)}) + } + root := m.MerkleRoot() + for i := uint64(0); i < 16; i++ { + proof := m.MerkleProof(baseAddr + i) + verifyProof(t, root, proof, baseAddr+i) + } + }) + + t.Run("cross-page addresses", func(t *testing.T) { + m := NewMemory() + pageSize := uint64(4096) + addresses := []uint64{ + pageSize - 2, + pageSize - 1, + pageSize, + pageSize + 1, + 2*pageSize - 2, + 2*pageSize - 1, + 2 * pageSize, + 2*pageSize + 1, + } + for i, addr := range addresses { + m.SetUnaligned(addr, []byte{byte(i + 1)}) + } + root := m.MerkleRoot() + for _, addr := range addresses { + proof := m.MerkleProof(addr) + verifyProof(t, root, proof, addr) + } + }) + + t.Run("large addresses", func(t *testing.T) { + m := NewMemory() + addresses := []uint64{ + 0x10_00_00_00_00_00_00_00, + 0x10_00_00_00_00_00_00_02, + 0x10_00_00_00_00_00_00_04, + 0x10_00_00_00_00_00_00_06, + } + for i, addr := range addresses { + m.SetUnaligned(addr, []byte{byte(i + 1)}) + } + root := m.MerkleRoot() + for _, addr := range addresses { + proof := m.MerkleProof(addr) + verifyProof(t, root, proof, addr) + } + }) +} +func TestMerkleProofWithPartialPaths(t *testing.T) { + testCases := []struct { + name string + setupMemory func(*Memory) + proofAddr uint64 + }{ + { + name: "Path ends at level 1", + setupMemory: func(m *Memory) { + m.SetUnaligned(0x10_00_00_00_00_00_00_00, []byte{1}) + }, + proofAddr: 0x20_00_00_00_00_00_00_00, + }, + { + name: "Path ends at level 2", + setupMemory: func(m *Memory) { + m.SetUnaligned(0x10_00_00_00_00_00_00_00, []byte{1}) + }, + proofAddr: 0x11_00_00_00_00_00_00_00, + }, + { + name: "Path ends at level 3", + setupMemory: func(m *Memory) { + m.SetUnaligned(0x10_10_00_00_00_00_00_00, []byte{1}) + }, + proofAddr: 0x10_11_00_00_00_00_00_00, + }, + { + name: "Path ends at level 4", + setupMemory: func(m *Memory) { + m.SetUnaligned(0x10_10_10_00_00_00_00_00, []byte{1}) + }, + proofAddr: 0x10_10_11_00_00_00_00_00, + }, + { + name: "Full path to level 5, page doesn't exist", + setupMemory: func(m *Memory) { + m.SetUnaligned(0x10_10_10_10_00_00_00_00, []byte{1}) + }, + proofAddr: 0x10_10_10_10_10_00_00_00, // Different page in the same level 5 node + }, + { + name: "Path ends at level 3, check different page offsets", + setupMemory: func(m *Memory) { + m.SetUnaligned(0x10_10_00_00_00_00_00_00, []byte{1}) + m.SetUnaligned(0x10_10_00_00_00_00_10_00, []byte{2}) + }, + proofAddr: 0x10_10_00_00_00_00_20_00, // Different offset in the same page + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := NewMemory() + tc.setupMemory(m) + + proof := m.MerkleProof(tc.proofAddr) + + // Check that the proof is filled correctly + verifyProof(t, m.MerkleRoot(), proof, tc.proofAddr) + //checkProof(t, proof, tc.expectedDepth) + }) + } +} + +func verifyProof(t *testing.T, expectedRoot [32]byte, proof [ProofLen * 32]byte, addr uint64) { + node := *(*[32]byte)(proof[:32]) + path := addr >> 5 + for i := 32; i < len(proof); i += 32 { + sib := *(*[32]byte)(proof[i : i+32]) + if path&1 != 0 { + node = HashPair(sib, node) + } else { + node = HashPair(node, sib) + } + path >>= 1 + } + require.Equal(t, expectedRoot, node, "proof must verify for address 0x%x", addr) } func TestMemoryMerkleRoot(t *testing.T) { @@ -77,28 +294,36 @@ func TestMemoryMerkleRoot(t *testing.T) { root := m.MerkleRoot() require.Equal(t, 
zeroHashes[64-5], root, "zero still") }) + t.Run("random few pages", func(t *testing.T) { m := NewMemory() m.SetUnaligned(PageSize*3, []byte{1}) m.SetUnaligned(PageSize*5, []byte{42}) m.SetUnaligned(PageSize*6, []byte{123}) - p3 := m.MerkleizeSubtree((1 << PageKeySize) | 3) - p5 := m.MerkleizeSubtree((1 << PageKeySize) | 5) - p6 := m.MerkleizeSubtree((1 << PageKeySize) | 6) - z := zeroHashes[PageAddrSize-5] + + p0 := m.radix.MerkleizeNode(0, 8) + p1 := m.radix.MerkleizeNode(0, 9) + p2 := m.radix.MerkleizeNode(0, 10) + p3 := m.radix.MerkleizeNode(0, 11) + p4 := m.radix.MerkleizeNode(0, 12) + p5 := m.radix.MerkleizeNode(0, 13) + p6 := m.radix.MerkleizeNode(0, 14) + p7 := m.radix.MerkleizeNode(0, 15) + r1 := HashPair( HashPair( - HashPair(z, z), // 0,1 - HashPair(z, p3), // 2,3 + HashPair(p0, p1), // 0,1 + HashPair(p2, p3), // 2,3 ), HashPair( - HashPair(z, p5), // 4,5 - HashPair(p6, z), // 6,7 + HashPair(p4, p5), // 4,5 + HashPair(p6, p7), // 6,7 ), ) - r2 := m.MerkleizeSubtree(1 << (PageKeySize - 3)) + r2 := m.MerkleRoot() require.Equal(t, r1, r2, "expecting manual page combination to match subtree merkle func") }) + t.Run("invalidate page", func(t *testing.T) { m := NewMemory() m.SetUnaligned(0xF000, []byte{0}) @@ -114,7 +339,7 @@ func TestMemoryReadWrite(t *testing.T) { t.Run("large random", func(t *testing.T) { m := NewMemory() data := make([]byte, 20_000) - _, err := rand.Read(data[:]) + _, err := cryptorand.Read(data[:]) require.NoError(t, err) require.NoError(t, m.SetMemoryRange(0, bytes.NewReader(data))) for _, i := range []uint64{0, 1, 2, 3, 4, 5, 6, 7, 1000, 3333, 4095, 4096, 4097, 20_000 - 32} { diff --git a/rvgo/fast/page.go b/rvgo/fast/page.go index da5bd0b..24933b4 100644 --- a/rvgo/fast/page.go +++ b/rvgo/fast/page.go @@ -85,3 +85,35 @@ func (p *CachedPage) MerkleizeSubtree(gindex uint64) [32]byte { } return p.Cache[gindex] } + +func (p *CachedPage) MerkleizeNode(addr, gindex uint64) [32]byte { + _ = p.MerkleRoot() // fill cache + if gindex >= PageSize/32 { + if gindex >= PageSize/32*2 { + panic("gindex too deep") + } + + // it's pointing to a bottom node + nodeIndex := gindex & (PageAddrMask >> 5) + return *(*[32]byte)(p.Data[nodeIndex*32 : nodeIndex*32+32]) + } + return p.Cache[gindex] +} + +func (p *CachedPage) GenerateProof(addr uint64) [][32]byte { + // Page-level proof + pageGindex := PageSize>>5 + (addr&PageAddrMask)>>5 + + proofs := make([][32]byte, 8) + proofIndex := 0 + + proofs[proofIndex] = p.MerkleizeSubtree(pageGindex) + + for idx := pageGindex; idx > 1; idx >>= 1 { + sibling := idx ^ 1 + proofIndex++ + proofs[proofIndex] = p.MerkleizeSubtree(uint64(sibling)) + } + + return proofs +} diff --git a/rvgo/fast/radix.go b/rvgo/fast/radix.go new file mode 100644 index 0000000..3cb755b --- /dev/null +++ b/rvgo/fast/radix.go @@ -0,0 +1,465 @@ +package fast + +import ( + "math/bits" +) + +// RadixNode is an interface defining the operations for a node in a radix trie. +type RadixNode interface { + // GenerateProof generates the Merkle proof for the given address. + GenerateProof(addr uint64, proofs [][32]byte) + // MerkleizeNode computes the Merkle root hash for the node at the given generalized index. + MerkleizeNode(addr, gindex uint64) [32]byte +} + +// SmallRadixNode is a radix trie node with a branching factor of 4 bits. +type SmallRadixNode[C RadixNode] struct { + Children [1 << 4]*C // Array of child nodes, indexed by 4-bit keys. + Hashes [1 << 4][32]byte // Cached hashes for intermediate hash node. 
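+	// Hashes, HashExists and HashValid are addressed by the in-node generalized index
+	// (1..15) of the intermediate nodes in the binary hash tree laid over the 16 children.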
+ HashExists uint16 // Bitmask indicating if the intermediate hash exist (1 bit per intermediate node). + HashValid uint16 // Bitmask indicating if the intermediate hashes are valid (1 bit per intermediate node). + Depth uint64 // The depth of this node in the trie (number of bits from the root). +} + +// MediumRadixNode is a radix trie node with a branching factor of 6 bits. +type MediumRadixNode[C RadixNode] struct { + Children [1 << 6]*C // Array of child nodes, indexed by 6-bit keys. + Hashes [1 << 6][32]byte + HashExists uint64 + HashValid uint64 + Depth uint64 +} + +// LargeRadixNode is a radix trie node with a branching factor of 16 bits. +type LargeRadixNode[C RadixNode] struct { + Children [1 << 16]*C // Array of child nodes, indexed by 16-bit keys. + Hashes [1 << 16][32]byte + HashExists [(1 << 16) / 64]uint64 + HashValid [(1 << 16) / 64]uint64 + Depth uint64 +} + +// Define a sequence of radix trie node types (L1 to L7) representing different levels in the trie. +// Each level corresponds to a node type, where L1 is the root node and L7 is the leaf level pointing to Memory. +// The cumulative bit-lengths of the addresses represented by the nodes from L1 to L7 add up to 52 bits. + +type L1 = LargeRadixNode[L2] +type L2 = *LargeRadixNode[L3] +type L3 = *MediumRadixNode[L4] +type L4 = *MediumRadixNode[L5] +type L5 = *SmallRadixNode[L6] +type L6 = *SmallRadixNode[L7] +type L7 = *Memory + +// InvalidateNode invalidates the hash cache along the path to the specified address. +// It marks the necessary intermediate hashes as invalid, forcing them to be recomputed when needed. +func (n *SmallRadixNode[C]) InvalidateNode(gindex uint64) { + branchIdx := (gindex + 1<<4) / 2 // Compute the index for the hash tree traversal. + + // Traverse up the hash tree, invalidating hashes along the way. + for index := branchIdx; index > 0; index >>= 1 { + hashBit := index & 15 // Get the relevant bit position (0-15). + n.HashExists |= 1 << hashBit // Mark the intermediate hash path as existing. + n.HashValid &= ^(1 << hashBit) // Invalidate the hash at this position. + } +} + +func (n *MediumRadixNode[C]) InvalidateNode(gindex uint64) { + branchIdx := (gindex + 1<<6) / 2 + + for index := branchIdx; index > 0; index >>= 1 { + hashBit := index & 63 + n.HashExists |= 1 << hashBit + n.HashValid &= ^(1 << hashBit) + } +} + +func (n *LargeRadixNode[C]) InvalidateNode(gindex uint64) { + branchIdx := (gindex + 1<<16) / 2 + + for index := branchIdx; index > 0; index >>= 1 { + hashIndex := index >> 6 + hashBit := index & 63 + n.HashExists[hashIndex] |= 1 << hashBit + n.HashValid[hashIndex] &= ^(1 << hashBit) + } +} + +// GenerateProof generates the Merkle proof for the given address. +// It collects the necessary sibling hashes along the path to reconstruct the Merkle proof. +func (n *SmallRadixNode[C]) GenerateProof(addr uint64, proofs [][32]byte) { + path := addressToRadixPath(addr, n.Depth, 4) + + if n.Children[path] == nil { + // When no child exists at this path, the rest of the proofs are zero hashes. + fillZeroHashRange(proofs, 0, 60-n.Depth-4) + } else { + // Recursively generate proofs from the child node. + (*n.Children[path]).GenerateProof(addr, proofs) + } + + // Collect sibling hashes along the path for the proof. + proofIndex := 60 - n.Depth - 4 + for idx := path + 1<<4; idx > 1; idx >>= 1 { + sibling := idx ^ 1 // Get the sibling index. 
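+		// Only the address bits consumed by the ancestors of this node are passed down;
+		// MerkleizeNode re-appends this node's own child bits as it descends.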
+ proofs[proofIndex] = n.MerkleizeNode(addr>>(64-n.Depth), sibling) + proofIndex += 1 + } +} + +func (n *MediumRadixNode[C]) GenerateProof(addr uint64, proofs [][32]byte) { + path := addressToRadixPath(addr, n.Depth, 6) + + if n.Children[path] == nil { + fillZeroHashRange(proofs, 0, 60-n.Depth-6) + } else { + (*n.Children[path]).GenerateProof(addr, proofs) + } + + proofIndex := 60 - n.Depth - 6 + for idx := path + 1<<6; idx > 1; idx >>= 1 { + sibling := idx ^ 1 + proofs[proofIndex] = n.MerkleizeNode(addr>>(64-n.Depth), sibling) + proofIndex += 1 + } +} + +func (n *LargeRadixNode[C]) GenerateProof(addr uint64, proofs [][32]byte) { + path := addressToRadixPath(addr, n.Depth, 16) + + if n.Children[path] == nil { + fillZeroHashRange(proofs, 0, 60-n.Depth-16) + } else { + (*n.Children[path]).GenerateProof(addr, proofs) + } + proofIndex := 60 - n.Depth - 16 + for idx := path + 1<<16; idx > 1; idx >>= 1 { + sibling := idx ^ 1 + proofs[proofIndex] = n.MerkleizeNode(addr>>(64-n.Depth), sibling) + proofIndex += 1 + } +} + +func (m *Memory) GenerateProof(addr uint64, proofs [][32]byte) { + pageIndex := addr >> PageAddrSize + + // number of proof for a page is 8 + // 0: leaf page data, 7: page's root + if p, ok := m.pages[pageIndex]; ok { + pageProofs := p.GenerateProof(addr) // Generate proof from the page. + copy(proofs[:8], pageProofs) + } else { + fillZeroHashRange(proofs, 0, 8) // Return zero hashes if the page does not exist. + } +} + +// MerkleizeNode computes the Merkle root hash for the node at the given generalized index. +// It recursively computes the hash of the subtree rooted at the given index. +// Note: The 'addr' parameter represents the partial address accumulated up to this node, not the full address. It represents the path taken in the trie to reach this node. +func (n *SmallRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte { + depth := uint64(bits.Len64(gindex)) // Get the depth of the current gindex. + + if depth > 5 { + panic("gindex too deep") + } + + // Leaf node of the radix trie (17~32) + if depth > 4 { + childIndex := gindex - 1<<4 + if n.Children[childIndex] == nil { + // Return zero hash if child does not exist. + return zeroHashes[64-5+1-(depth+n.Depth)] + } + + // Update the partial address by appending the child index bits. + // This accumulates the address as we traverse deeper into the trie. + addr <<= 4 + addr |= childIndex + return (*n.Children[childIndex]).MerkleizeNode(addr, 1) + } + + // Intermediate node of the radix trie (0~15) + hashBit := gindex & 15 + + if (n.HashExists & (1 << hashBit)) != 0 { + if (n.HashValid & (1 << hashBit)) != 0 { + // Return the cached hash if valid. + return n.Hashes[gindex] + } else { + left := n.MerkleizeNode(addr, gindex<<1) + right := n.MerkleizeNode(addr, (gindex<<1)|1) + + // Hash the pair and cache the result. + r := HashPair(left, right) + n.Hashes[gindex] = r + n.HashValid |= 1 << hashBit + return r + } + } else { + // Return zero hash for non-existent child. + return zeroHashes[64-5+1-(depth+n.Depth)] + } +} + +func (n *MediumRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte { + depth := uint64(bits.Len64(gindex)) // Get the depth of the current gindex. + + if depth > 7 { + panic("gindex too deep") + } + + // Leaf node of the radix trie (64~128) + if depth > 6 { + childIndex := gindex - 1<<6 + + if n.Children[childIndex] == nil { + // Return zero hash if child does not exist. + return zeroHashes[64-5+1-(depth+n.Depth)] + } + + // Update the partial address by appending the child index bits. 
+		// This accumulates the address as we traverse deeper into the trie.
+		addr <<= 6
+		addr |= childIndex
+		return (*n.Children[childIndex]).MerkleizeNode(addr, 1)
+	}
+
+	// Intermediate node of the radix trie (1~63)
+	hashBit := gindex & 63
+
+	if (n.HashExists & (1 << hashBit)) != 0 {
+		if (n.HashValid & (1 << hashBit)) != 0 {
+			// Return the cached hash if valid.
+			return n.Hashes[gindex]
+		} else {
+			left := n.MerkleizeNode(addr, gindex<<1)
+			right := n.MerkleizeNode(addr, (gindex<<1)|1)
+
+			// Hash the pair and cache the result.
+			r := HashPair(left, right)
+			n.Hashes[gindex] = r
+			n.HashValid |= 1 << hashBit
+			return r
+		}
+	} else {
+		// Return zero hash for non-existent child.
+		return zeroHashes[64-5+1-(depth+n.Depth)]
+	}
+}
+
+func (n *LargeRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte {
+	depth := uint64(bits.Len64(gindex))
+
+	if depth > 17 {
+		panic("gindex too deep")
+	}
+
+	// Leaf node of the radix trie (2^16 ~ 2^17-1)
+	if depth > 16 {
+		childIndex := gindex - 1<<16
+		if n.Children[int(childIndex)] == nil {
+			return zeroHashes[64-5+1-(depth+n.Depth)]
+		}
+
+		addr <<= 16
+		addr |= childIndex
+		return (*n.Children[childIndex]).MerkleizeNode(addr, 1)
+	}
+
+	// Intermediate node of the radix trie (1 ~ 2^16-1)
+	hashIndex := gindex >> 6
+	hashBit := gindex & 63
+	if (n.HashExists[hashIndex] & (1 << hashBit)) != 0 {
+		if (n.HashValid[hashIndex] & (1 << hashBit)) != 0 {
+			return n.Hashes[gindex]
+		} else {
+			left := n.MerkleizeNode(addr, gindex<<1)
+			right := n.MerkleizeNode(addr, (gindex<<1)|1)
+
+			r := HashPair(left, right)
+			n.Hashes[gindex] = r
+			n.HashValid[hashIndex] |= 1 << hashBit
+			return r
+		}
+	} else {
+		return zeroHashes[64-5+1-(depth+n.Depth)]
+	}
+}
+
+func (m *Memory) MerkleizeNode(addr, gindex uint64) [32]byte {
+	depth := uint64(bits.Len64(gindex))
+
+	pageIndex := addr
+	if p, ok := m.pages[pageIndex]; ok {
+		return p.MerkleRoot()
+	} else {
+		// All 52 page-index bits are consumed above this leaf, so a missing page
+		// hashes to the zero hash of a full page subtree.
+		return zeroHashes[64-5+1-(depth+52)]
+	}
+}
+
+// MerkleRoot computes the Merkle root hash of the entire memory.
+func (m *Memory) MerkleRoot() [32]byte {
+	return (*m.radix).MerkleizeNode(0, 1)
+}
+
+// MerkleProof generates the Merkle proof for the specified address in memory.
+func (m *Memory) MerkleProof(addr uint64) [ProofLen * 32]byte {
+	proofs := make([][32]byte, 60)
+	m.radix.GenerateProof(addr, proofs)
+	return encodeProofs(proofs)
+}
+
+// fillZeroHashRange fills slice[start:end] with zero hashes: slice[0] gets the zero leaf,
+// and slice[i] gets the zero hash of a subtree of height i-1.
+func fillZeroHashRange(slice [][32]byte, start, end uint64) {
+	if start == 0 {
+		slice[0] = zeroHashes[0]
+		start++
+	}
+	for i := start; i < end; i++ {
+		slice[i] = zeroHashes[i-1]
+	}
+}
+
+// encodeProofs encodes the list of proof hashes into a byte array.
+func encodeProofs(proofs [][32]byte) [ProofLen * 32]byte {
+	var out [ProofLen * 32]byte
+	for i := 0; i < ProofLen; i++ {
+		copy(out[i*32:(i+1)*32], proofs[i][:])
+	}
+	return out
+}
+
+// addressToRadixPath extracts a segment of bits from an address, starting from 'position' with 'count' bits.
+// It returns the extracted bits as a uint64.
+func addressToRadixPath(addr, position, count uint64) uint64 {
+	// Calculate the total shift amount.
+	totalShift := 64 - position - count
+
+	// Shift the address to bring the desired bits to the LSB.
+	addr >>= totalShift
+
+	// Extract the desired bits using a mask.
+	return addr & ((1 << count) - 1)
+}
+
+// addressToRadixPaths converts an address into a slice of radix path indices based on the branch factors.
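+// For example, with branch factors {16, 16, 6, 6, 4, 4}, address 0x10_00_00_00_00_00_10_00
+// yields the per-level indices [0x1000, 0, 0, 0, 0, 0x1]: the top 16 bits of the 52-bit
+// page index select the child of the root node, and so on down to the last 4-bit index,
+// while the low 12 bits (the in-page offset) are never consumed by the radix levels.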
+func (m *Memory) addressToRadixPaths(addr uint64) []uint64 { + path := make([]uint64, len(m.branchFactors)) + var position uint64 + + for index, branchFactor := range m.branchFactors { + path[index] = addressToRadixPath(addr, position, branchFactor) + position += branchFactor + } + + return path +} + +// AllocPage allocates a new page at the specified page index in memory. +func (m *Memory) AllocPage(pageIndex uint64) *CachedPage { + p := &CachedPage{Data: new(Page)} + m.pages[pageIndex] = p + + addr := pageIndex << PageAddrSize + branchPaths := m.addressToRadixPaths(addr) + depth := uint64(0) + + // Build the radix trie path to the new page, creating nodes as necessary. + // This code is a bit repetitive, but better for the compiler to optimize. + radixLevel1 := m.radix + depth += m.branchFactors[0] + if (*radixLevel1).Children[branchPaths[0]] == nil { + node := &LargeRadixNode[L3]{Depth: depth} + (*radixLevel1).Children[branchPaths[0]] = &node + } + (*radixLevel1).InvalidateNode(branchPaths[0]) + + radixLevel2 := *(*radixLevel1).Children[branchPaths[0]] + depth += m.branchFactors[1] + if (radixLevel2).Children[branchPaths[1]] == nil { + node := &MediumRadixNode[L4]{Depth: depth} + (radixLevel2).Children[branchPaths[1]] = &node + } + (radixLevel2).InvalidateNode(branchPaths[1]) + + radixLevel3 := *(*radixLevel2).Children[branchPaths[1]] + depth += m.branchFactors[2] + if (radixLevel3).Children[branchPaths[2]] == nil { + node := &MediumRadixNode[L5]{Depth: depth} + (radixLevel3).Children[branchPaths[2]] = &node + } + (radixLevel3).InvalidateNode(branchPaths[2]) + + radixLevel4 := *(*radixLevel3).Children[branchPaths[2]] + depth += m.branchFactors[3] + if (radixLevel4).Children[branchPaths[3]] == nil { + node := &SmallRadixNode[L6]{Depth: depth} + (radixLevel4).Children[branchPaths[3]] = &node + } + (radixLevel4).InvalidateNode(branchPaths[3]) + + radixLevel5 := *(*radixLevel4).Children[branchPaths[3]] + depth += m.branchFactors[4] + if (radixLevel5).Children[branchPaths[4]] == nil { + node := &SmallRadixNode[L7]{Depth: depth} + (radixLevel5).Children[branchPaths[4]] = &node + } + (radixLevel5).InvalidateNode(branchPaths[4]) + + radixLevel6 := *(*radixLevel5).Children[branchPaths[4]] + (radixLevel6).Children[branchPaths[5]] = &m + (radixLevel6).InvalidateNode(branchPaths[5]) + + return p +} + +// Invalidate invalidates the cache along the path from the specified address up to the root. +// It ensures that any cached hashes are recomputed when needed. +func (m *Memory) Invalidate(addr uint64) { + // Find the page and invalidate the address within it. + if p, ok := m.pageLookup(addr >> PageAddrSize); ok { + prevValid := p.Ok[1] + if !prevValid { + // If the page was already invalid, the nodes up to the root are also invalid. 
+			// The radix nodes up to the root are already invalid, but this address's
+			// path inside the page cache may still hold stale hashes, so invalidate it too.
+			p.Invalidate(addr & PageAddrMask)
+			return
+		}
+		p.Invalidate(addr & PageAddrMask)
+	} else {
+		return
+	}
+
+	branchPaths := m.addressToRadixPaths(addr)
+
+	currentLevel1 := m.radix
+	currentLevel1.InvalidateNode(branchPaths[0])
+
+	radixLevel2 := (*m.radix).Children[branchPaths[0]]
+	if radixLevel2 == nil {
+		return
+	}
+	(*radixLevel2).InvalidateNode(branchPaths[1])
+
+	radixLevel3 := (*radixLevel2).Children[branchPaths[1]]
+	if radixLevel3 == nil {
+		return
+	}
+	(*radixLevel3).InvalidateNode(branchPaths[2])
+
+	radixLevel4 := (*radixLevel3).Children[branchPaths[2]]
+	if radixLevel4 == nil {
+		return
+	}
+	(*radixLevel4).InvalidateNode(branchPaths[3])
+
+	radixLevel5 := (*radixLevel4).Children[branchPaths[3]]
+	if radixLevel5 == nil {
+		return
+	}
+	(*radixLevel5).InvalidateNode(branchPaths[4])
+
+	radixLevel6 := (*radixLevel5).Children[branchPaths[4]]
+	if radixLevel6 == nil {
+		return
+	}
+	(*radixLevel6).InvalidateNode(branchPaths[5])
+}