From af011d2a0faa1876c43728b21e695e501b6f0af7 Mon Sep 17 00:00:00 2001 From: Danno Ferrin Date: Tue, 24 Oct 2023 23:11:15 -0600 Subject: [PATCH] Use Bytes Trie to track warm addresses (#6069) * Use Bytes Trie to track warm addresses Move from a java HashSet to a custom Trie based on bytes to store the warm addresses, creates, and self-destructs. This avoids needing to calculate java hashes or engage in using custom Comparators. Signed-off-by: Danno Ferrin * codeql scan Signed-off-by: Danno Ferrin --------- Signed-off-by: Danno Ferrin Signed-off-by: Sally MacFarlane Co-authored-by: Sally MacFarlane --- .../mainnet/MainnetTransactionProcessor.java | 4 +- .../besu/collections/trie/BytesTrieSet.java | 312 ++++++++++++++++++ .../besu/evm/fluent/EVMExecutor.java | 4 +- .../besu/evm/frame/MessageFrame.java | 8 +- .../evm/worldstate/AbstractWorldUpdater.java | 5 +- .../collections/trie/BytesTrieSetTest.java | 176 ++++++++++ 6 files changed, 499 insertions(+), 10 deletions(-) create mode 100644 evm/src/main/java/org/hyperledger/besu/collections/trie/BytesTrieSet.java create mode 100644 evm/src/test/java/org/hyperledger/besu/collections/trie/BytesTrieSetTest.java diff --git a/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/mainnet/MainnetTransactionProcessor.java b/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/mainnet/MainnetTransactionProcessor.java index cdeec19b117..d8b938a2fa1 100644 --- a/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/mainnet/MainnetTransactionProcessor.java +++ b/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/mainnet/MainnetTransactionProcessor.java @@ -19,6 +19,7 @@ import static org.hyperledger.besu.ethereum.mainnet.PrivateStateUtils.KEY_TRANSACTION; import static org.hyperledger.besu.ethereum.mainnet.PrivateStateUtils.KEY_TRANSACTION_HASH; +import org.hyperledger.besu.collections.trie.BytesTrieSet; import org.hyperledger.besu.datatypes.AccessListEntry; import org.hyperledger.besu.datatypes.Address; import org.hyperledger.besu.datatypes.Wei; @@ -43,7 +44,6 @@ import org.hyperledger.besu.evm.worldstate.WorldUpdater; import java.util.Deque; -import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; @@ -318,7 +318,7 @@ public TransactionProcessingResult processTransaction( final List accessListEntries = transaction.getAccessList().orElse(List.of()); // we need to keep a separate hash set of addresses in case they specify no storage. // No-storage is a common pattern, especially for Externally Owned Accounts - final Set
addressList = new HashSet<>(); + final Set
addressList = new BytesTrieSet<>(Address.SIZE); final Multimap storageList = HashMultimap.create(); int accessListStorageCount = 0; for (final var entry : accessListEntries) { diff --git a/evm/src/main/java/org/hyperledger/besu/collections/trie/BytesTrieSet.java b/evm/src/main/java/org/hyperledger/besu/collections/trie/BytesTrieSet.java new file mode 100644 index 00000000000..20e2d49af5c --- /dev/null +++ b/evm/src/main/java/org/hyperledger/besu/collections/trie/BytesTrieSet.java @@ -0,0 +1,312 @@ +/* + * Copyright contributors to Hyperledger Besu + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ +package org.hyperledger.besu.collections.trie; + +import java.util.AbstractSet; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Objects; + +import org.apache.tuweni.bytes.Bytes; + +/** + * A Bytes optimized set that stores values in a trie by byte + * + * @param Type of trie + */ +public class BytesTrieSet extends AbstractSet { + + record Node(byte[] leafArray, E leafObject, Node[] children) { + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (!(o instanceof Node node)) return false; + return Arrays.equals(leafArray, node.leafArray) + && Objects.equals(leafObject, node.leafObject) + && Arrays.equals(children, node.children); + } + + @Override + public int hashCode() { + int result = Objects.hash(leafObject); + result = 31 * result + Arrays.hashCode(leafArray); + result = 31 * result + Arrays.hashCode(children); + return result; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("Node{"); + sb.append("leaf="); + if (leafObject == null) sb.append("null"); + else { + sb.append('['); + System.out.println(leafObject.toHexString()); + sb.append(']'); + } + sb.append(", children="); + if (children == null) sb.append("null"); + else { + sb.append('['); + for (int i = 0; i < children.length; ++i) { + if (children[i] == null) { + continue; + } + sb.append(i == 0 ? "" : ", ").append(i).append("=").append(children[i]); + } + sb.append(']'); + } + sb.append('}'); + return sb.toString(); + } + } + + Node root; + + int size = 0; + final int byteLength; + + /** + * Create a BytesTrieSet with a fixed length + * + * @param byteLength length in bytes of the stored types + */ + public BytesTrieSet(final int byteLength) { + this.byteLength = byteLength; + } + + static class NodeWalker { + final Node node; + int lastRead; + + NodeWalker(final Node node) { + this.node = node; + this.lastRead = -1; + } + + NodeWalker nextNodeWalker() { + if (node.children == null) { + return null; + } + while (lastRead < 255) { + lastRead++; + Node child = node.children[lastRead]; + if (child != null) { + return new NodeWalker<>(child); + } + } + return null; + } + + E thisNode() { + return node.leafObject; + } + } + + @Override + public Iterator iterator() { + var result = + new Iterator() { + final Deque> stack = new ArrayDeque<>(); + E next; + E last; + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public E next() { + if (next == null) { + throw new NoSuchElementException(); + } + last = next; + advance(); + return last; + } + + @Override + public void remove() { + BytesTrieSet.this.remove(last); + } + + void advance() { + while (!stack.isEmpty()) { + NodeWalker thisStep = stack.peek(); + var nextStep = thisStep.nextNodeWalker(); + if (nextStep == null) { + stack.pop(); + if (thisStep.thisNode() != null) { + next = thisStep.thisNode(); + return; + } + } else { + stack.push(nextStep); + } + } + next = null; + } + }; + if (root != null) { + result.stack.add(new NodeWalker<>(root)); + } + result.advance(); + return result; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean contains(final Object o) { + if (!(o instanceof Bytes bytes)) { + throw new IllegalArgumentException( + "Expected Bytes, got " + (o == null ? "null" : o.getClass().getName())); + } + byte[] array = bytes.toArrayUnsafe(); + if (array.length != byteLength) { + throw new IllegalArgumentException( + "Byte array is size " + array.length + " but set is size " + byteLength); + } + if (root == null) { + return false; + } + int level = 0; + Node current = root; + while (current != null) { + if (current.leafObject != null) { + return Arrays.compare(current.leafArray, array) == 0; + } + current = current.children[array[level] & 0xff]; + level++; + } + return false; + } + + @Override + public boolean remove(final Object o) { + // Two base cases, size==0 and size==1; + if (!(o instanceof Bytes bytes)) { + throw new IllegalArgumentException( + "Expected Bytes, got " + (o == null ? "null" : o.getClass().getName())); + } + byte[] array = bytes.toArrayUnsafe(); + if (array.length != byteLength) { + throw new IllegalArgumentException( + "Byte array is size " + array.length + " but set is size " + byteLength); + } + // Two base cases, size==0 and size==1; + if (root == null) { + // size==0 is easy, empty + return false; + } + if (root.leafObject != null) { + // size==1 just check and possibly remove the root + if (Arrays.compare(array, root.leafArray) == 0) { + root = null; + size--; + return true; + } else { + return false; + } + } + int level = 0; + Node current = root; + do { + int index = array[level] & 0xff; + Node next = current.children[index]; + if (next == null) { + return false; + } + if (next.leafObject != null) { + if (Arrays.compare(array, next.leafArray) == 0) { + // TODO there is no cleanup of empty branches + current.children[index] = null; + size--; + return true; + } else { + return false; + } + } + current = next; + + level++; + } while (true); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public boolean add(final E bytes) { + byte[] array = bytes.toArrayUnsafe(); + if (array.length != byteLength) { + throw new IllegalArgumentException( + "Byte array is size " + array.length + " but set is size " + byteLength); + } + // Two base cases, size==0 and size==1; + if (root == null) { + // size==0 is easy, just add + root = new Node<>(array, bytes, null); + size++; + return true; + } + if (root.leafObject != null) { + // size==1 first check then if no match make it look like n>1 + if (Arrays.compare(array, root.leafArray) == 0) { + return false; + } + Node oldRoot = root; + root = new Node<>(null, null, new Node[256]); + root.children[oldRoot.leafArray[0] & 0xff] = oldRoot; + } + int level = 0; + Node current = root; + do { + int index = array[level] & 0xff; + Node next = current.children[index]; + if (next == null) { + next = new Node<>(array, bytes, null); + current.children[index] = next; + size++; + return true; + } + if (next.leafObject != null) { + if (Arrays.compare(array, next.leafArray) == 0) { + return false; + } + Node newLeaf = new Node<>(null, null, new Node[256]); + newLeaf.children[next.leafArray[level + 1] & 0xff] = next; + current.children[index] = newLeaf; + next = newLeaf; + } + level++; + + current = next; + + } while (true); + } + + @Override + public void clear() { + root = null; + } +} diff --git a/evm/src/main/java/org/hyperledger/besu/evm/fluent/EVMExecutor.java b/evm/src/main/java/org/hyperledger/besu/evm/fluent/EVMExecutor.java index d046c64778e..9946b6f57a8 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/fluent/EVMExecutor.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/fluent/EVMExecutor.java @@ -16,6 +16,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import org.hyperledger.besu.collections.trie.BytesTrieSet; import org.hyperledger.besu.datatypes.Address; import org.hyperledger.besu.datatypes.Hash; import org.hyperledger.besu.datatypes.VersionedHash; @@ -41,7 +42,6 @@ import java.math.BigInteger; import java.util.Collection; import java.util.Deque; -import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -80,7 +80,7 @@ public class EVMExecutor { List.of(MaxCodeSizeRule.of(0x6000), PrefixCodeRule.of()); private long initialNonce = 1; private Collection
forceCommitAddresses = List.of(Address.fromHexString("0x03")); - private Set
accessListWarmAddresses = new HashSet<>(); + private Set
accessListWarmAddresses = new BytesTrieSet<>(Address.SIZE); private Multimap accessListWarmStorage = HashMultimap.create(); private MessageCallProcessor messageCallProcessor = null; private ContractCreationProcessor contractCreationProcessor = null; diff --git a/evm/src/main/java/org/hyperledger/besu/evm/frame/MessageFrame.java b/evm/src/main/java/org/hyperledger/besu/evm/frame/MessageFrame.java index ed711e3f181..64c865dfd41 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/frame/MessageFrame.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/frame/MessageFrame.java @@ -17,6 +17,7 @@ import static com.google.common.base.Preconditions.checkState; import static java.util.Collections.emptySet; +import org.hyperledger.besu.collections.trie.BytesTrieSet; import org.hyperledger.besu.collections.undo.UndoSet; import org.hyperledger.besu.collections.undo.UndoTable; import org.hyperledger.besu.datatypes.Address; @@ -38,7 +39,6 @@ import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -1707,7 +1707,7 @@ public MessageFrame build() { new TxValues( blockHashLookup, maxStackSize, - UndoSet.of(new HashSet<>()), + UndoSet.of(new BytesTrieSet<>(Address.SIZE)), UndoTable.of(HashBasedTable.create()), originator, gasPrice, @@ -1717,8 +1717,8 @@ public MessageFrame build() { miningBeneficiary, versionedHashes, UndoTable.of(HashBasedTable.create()), - UndoSet.of(new HashSet<>()), - UndoSet.of(new HashSet<>())); + UndoSet.of(new BytesTrieSet<>(Address.SIZE)), + UndoSet.of(new BytesTrieSet<>(Address.SIZE))); updater = worldUpdater; newStatic = isStatic; } else { diff --git a/evm/src/main/java/org/hyperledger/besu/evm/worldstate/AbstractWorldUpdater.java b/evm/src/main/java/org/hyperledger/besu/evm/worldstate/AbstractWorldUpdater.java index 41cacdbea06..2eb48dd1add 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/worldstate/AbstractWorldUpdater.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/worldstate/AbstractWorldUpdater.java @@ -14,6 +14,7 @@ */ package org.hyperledger.besu.evm.worldstate; +import org.hyperledger.besu.collections.trie.BytesTrieSet; import org.hyperledger.besu.datatypes.Address; import org.hyperledger.besu.datatypes.Wei; import org.hyperledger.besu.evm.account.Account; @@ -21,7 +22,6 @@ import java.util.Collection; import java.util.Collections; -import java.util.HashSet; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -44,7 +44,8 @@ public abstract class AbstractWorldUpdater> updatedAccounts = new ConcurrentHashMap<>(); /** The Deleted accounts. */ - protected Set
deletedAccounts = Collections.synchronizedSet(new HashSet<>()); + protected Set
deletedAccounts = + Collections.synchronizedSet(new BytesTrieSet<>(Address.SIZE)); /** * Instantiates a new Abstract world updater. diff --git a/evm/src/test/java/org/hyperledger/besu/collections/trie/BytesTrieSetTest.java b/evm/src/test/java/org/hyperledger/besu/collections/trie/BytesTrieSetTest.java new file mode 100644 index 00000000000..dc92cc47d72 --- /dev/null +++ b/evm/src/test/java/org/hyperledger/besu/collections/trie/BytesTrieSetTest.java @@ -0,0 +1,176 @@ +/* + * Copyright contributors to Hyperledger Besu + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + * + */ +package org.hyperledger.besu.collections.trie; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.apache.tuweni.bytes.Bytes; +import org.junit.jupiter.api.Test; + +class BytesTrieSetTest { + + private static final Bytes BYTES_1234 = Bytes.of(1, 2, 3, 4); + private static final Bytes BYTES_4321 = Bytes.of(4, 3, 2, 1); + private static final Bytes BYTES_4567 = Bytes.of(4, 5, 6, 7); + private static final Bytes BYTES_4568 = Bytes.of(4, 5, 6, 8); + private static final Bytes BYTES_4556 = Bytes.of(4, 5, 5, 6); + private static final Bytes BYTES_123 = Bytes.of(1, 2, 3); + + @Test + void testInserts() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + assertThat(trieSet).isEmpty(); + System.out.println(trieSet); + + assertThat(trieSet.add(BYTES_1234)).isTrue(); + assertThat(trieSet).hasSize(1); + + assertThat(trieSet.add(BYTES_1234)).isFalse(); + assertThat(trieSet).hasSize(1); + + assertThat(trieSet.add(BYTES_4321)).isTrue(); + assertThat(trieSet).hasSize(2); + + assertThat(trieSet.add(BYTES_4567)).isTrue(); + assertThat(trieSet).hasSize(3); + + assertThat(trieSet.add(BYTES_4567)).isFalse(); + assertThat(trieSet).hasSize(3); + + System.out.println(trieSet); + } + + @Test + void testRemoves() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + + assertThat(trieSet.remove(BYTES_1234)).isFalse(); + + trieSet.add(BYTES_1234); + assertThat(trieSet.remove(BYTES_4321)).isFalse(); + assertThat(trieSet.remove(BYTES_1234)).isTrue(); + assertThat(trieSet).isEmpty(); + + trieSet.add(BYTES_1234); + trieSet.add(BYTES_4321); + assertThat(trieSet.remove(BYTES_4567)).isFalse(); + assertThat(trieSet.remove(BYTES_4568)).isFalse(); + + trieSet.add(BYTES_4567); + trieSet.add(BYTES_4568); + assertThat(trieSet).hasSize(4); + + assertThat(trieSet.remove(BYTES_4556)).isFalse(); + assertThat(trieSet.remove(BYTES_4568)).isTrue(); + assertThat(trieSet.remove(BYTES_4568)).isFalse(); + assertThat(trieSet).hasSize(3); + assertThat(trieSet.remove(BYTES_4567)).isTrue(); + assertThat(trieSet).hasSize(2); + + assertThat(trieSet.remove(BYTES_4321)).isTrue(); + assertThat(trieSet).hasSize(1); + + assertThat(trieSet.remove(BYTES_1234)).isTrue(); + assertThat(trieSet).isEmpty(); + } + + @Test + @SuppressWarnings( + "squid:S5838") // contains and doesNotContains uses iterables, not the contains method + void testContains() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + + assertThat(trieSet.contains(BYTES_1234)).isFalse(); + + trieSet.add(BYTES_1234); + assertThat(trieSet.contains(BYTES_4321)).isFalse(); + assertThat(trieSet.contains(BYTES_1234)).isTrue(); + assertThat(trieSet).hasSize(1); + + trieSet.add(BYTES_1234); + trieSet.add(BYTES_4321); + assertThat(trieSet.contains(BYTES_4567)).isFalse(); + assertThat(trieSet.contains(BYTES_4568)).isFalse(); + + trieSet.add(BYTES_4567); + trieSet.add(BYTES_4568); + assertThat(trieSet).hasSize(4); + + assertThat(trieSet.contains(BYTES_4556)).isFalse(); + assertThat(trieSet.contains(BYTES_4568)).isTrue(); + trieSet.remove(BYTES_4568); + assertThat(trieSet).hasSize(3); + assertThat(trieSet.contains(BYTES_4567)).isTrue(); + trieSet.remove(BYTES_4567); + assertThat(trieSet).hasSize(2); + assertThat(trieSet.contains(BYTES_4567)).isFalse(); + + assertThat(trieSet.contains(BYTES_4321)).isTrue(); + trieSet.remove(BYTES_4321); + assertThat(trieSet.contains(BYTES_4321)).isFalse(); + assertThat(trieSet).hasSize(1); + + assertThat(trieSet.contains(BYTES_1234)).isTrue(); + trieSet.remove(BYTES_1234); + assertThat(trieSet.contains(BYTES_4321)).isFalse(); + + assertThat(trieSet).isEmpty(); + } + + @Test + @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") + void checkLengthAdd() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + assertThatThrownBy(() -> trieSet.add(BYTES_123)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") + void checkLengthRemove() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + assertThatThrownBy(() -> trieSet.remove(BYTES_123)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("4"); + } + + @Test + @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") + void checkLengthContains() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + assertThatThrownBy(() -> trieSet.contains(BYTES_123)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("4"); + } + + @Test + @SuppressWarnings({"MismatchedQueryAndUpdateOfCollection", "SuspiciousMethodCalls"}) + void checkWrongClassRemove() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + assertThatThrownBy(() -> trieSet.remove(this)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Bytes"); + } + + @Test + @SuppressWarnings({"MismatchedQueryAndUpdateOfCollection", "SuspiciousMethodCalls"}) + void checkWrongClassContains() { + BytesTrieSet trieSet = new BytesTrieSet<>(4); + assertThatThrownBy(() -> trieSet.contains(this)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Bytes"); + } +}