-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create support for dynamic byte arrays
- Loading branch information
1 parent
19cbd66
commit 72ab19c
Showing
2 changed files
with
259 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
// SPDX-License-Identifier: MIT | ||
pragma solidity ^0.8.24; | ||
|
||
/// @title TransientContextBytes | ||
/// @notice Library for transient storage. | ||
library TransientContextBytes { | ||
error DataTooLarge(); | ||
error OutOfOrderSlots(); | ||
error RangeTooLarge(); | ||
|
||
/// @dev 4-bytes is way above current max contract size, meant to account for future EVM | ||
/// versions. | ||
uint256 internal constant LENGTH_MASK = 0xffffffff; | ||
uint256 internal constant MAX_LENGTH = LENGTH_MASK; | ||
uint256 internal constant LENGTH_BYTES = 4; | ||
|
||
/// @notice Slot for call depth. | ||
/// Equal to bytes32(uint256(keccak256("transient.calldepth")) - 1). | ||
bytes32 internal constant CALL_DEPTH_SLOT = 0x7a74fd168763fd280eaec3bcd2fd62d0e795027adc8183a693c497a7c2b10b5c; | ||
|
||
/// @notice Gets the call depth. | ||
/// @return callDepth_ Current call depth. | ||
function callDepth() internal view returns (uint256 callDepth_) { | ||
assembly ("memory-safe") { | ||
callDepth_ := tload(CALL_DEPTH_SLOT) | ||
} | ||
} | ||
|
||
/// @notice Gets bytes value in transient storage for a slot at the current call depth. | ||
/// @param _slot Slot to get. | ||
/// @return _value Transient bytes value. | ||
function get(bytes32 _slot) internal view returns (bytes memory _value) { | ||
assembly ("memory-safe") { | ||
// Allocate and load head. | ||
_value := mload(0x40) | ||
mstore(_value, 0) | ||
mstore(0, tload(CALL_DEPTH_SLOT)) | ||
mstore(32, _slot) | ||
let slot := keccak256(0, 64) | ||
mstore(add(_value, sub(0x20, LENGTH_BYTES)), tload(slot)) | ||
// Get length and update free pointer. | ||
let _valueStart := add(_value, 0x20) | ||
let len := mload(_value) | ||
mstore(0x40, add(_valueStart, len)) | ||
|
||
if gt(len, sub(0x20, LENGTH_BYTES)) { | ||
// Derive extended slots. | ||
mstore(0x00, slot) | ||
slot := keccak256(0x00, 0x20) | ||
|
||
// Store remainder. | ||
let offset := add(_valueStart, sub(0x20, LENGTH_BYTES)) | ||
let endOffset := add(_valueStart, len) | ||
for {} 1 {} { | ||
mstore(offset, tload(slot)) | ||
offset := add(offset, 0x20) | ||
if gt(offset, endOffset) { break } | ||
slot := add(slot, 1) | ||
} | ||
mstore(endOffset, 0) | ||
} | ||
} | ||
} | ||
|
||
/// @notice Sets a bytes value in transient storage for a slot at the current call depth. | ||
/// @param _slot Slot to set. | ||
/// @param _value Value to set. | ||
function set(bytes32 _slot, bytes memory _value) internal { | ||
assembly ("memory-safe") { | ||
let len := mload(_value) | ||
|
||
if gt(len, LENGTH_MASK) { | ||
mstore(0x00, 0x54ef47ee /* DataTooLarge() */ ) | ||
revert(0x1c, 0x04) | ||
} | ||
|
||
// Store first word packed with length | ||
let _valueStart := add(_value, 0x20) | ||
let head := mload(sub(_valueStart, LENGTH_BYTES)) | ||
|
||
mstore(0, tload(CALL_DEPTH_SLOT)) | ||
mstore(32, _slot) | ||
let slot := keccak256(0, 64) | ||
|
||
tstore(slot, head) | ||
|
||
if gt(len, sub(0x20, LENGTH_BYTES)) { | ||
// Derive extended slots. | ||
mstore(0x00, slot) | ||
slot := keccak256(0x00, 0x20) | ||
|
||
// Store remainder. | ||
let offset := add(_valueStart, sub(0x20, LENGTH_BYTES)) | ||
// Ensure each loop can do cheap comparison to see if it's at the end. | ||
let endOffset := sub(add(_valueStart, len), 1) | ||
for {} 1 {} { | ||
tstore(slot, mload(offset)) | ||
offset := add(offset, 0x20) | ||
if gt(offset, endOffset) { break } | ||
slot := add(slot, 1) | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// @notice Increments call depth. | ||
/// This function can overflow. However, this is ok because there's still | ||
/// only one value stored per slot. | ||
function increment() internal { | ||
assembly ("memory-safe") { | ||
tstore(CALL_DEPTH_SLOT, add(tload(CALL_DEPTH_SLOT), 1)) | ||
} | ||
} | ||
|
||
/// @notice Decrements call depth. | ||
/// This function can underflow. However, this is ok because there's still | ||
/// only one value stored per slot. | ||
function decrement() internal { | ||
assembly ("memory-safe") { | ||
tstore(CALL_DEPTH_SLOT, sub(tload(CALL_DEPTH_SLOT), 1)) | ||
} | ||
} | ||
} | ||
|
||
/// @title TransientReentrancyAware | ||
/// @notice Reentrancy-aware modifier for transient storage, which increments and | ||
/// decrements the call depth when entering and exiting a function. | ||
contract TransientReentrancyAware { | ||
/// @notice Modifier to make a function reentrancy-aware. | ||
modifier reentrantAware() { | ||
TransientContextBytes.increment(); | ||
_; | ||
TransientContextBytes.decrement(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
// SPDX-License-Identifier: MIT | ||
pragma solidity ^0.8.24; | ||
|
||
// Testing utilities | ||
import {Test} from "forge-std/Test.sol"; | ||
|
||
// Target contracts | ||
import {TransientContextBytes} from "src/TransientContextBytes.sol"; | ||
import {TransientReentrancyAware} from "src/TransientContextBytes.sol"; | ||
|
||
/// @title TransientContextBytesTest | ||
/// @notice Tests for the TransientContext library with bytes. | ||
contract TransientContextBytesTest is Test { | ||
/// @notice Slot for call depth. | ||
bytes32 internal callDepthSlot = bytes32(uint256(keccak256("transient.calldepth")) - 1); | ||
|
||
/// @notice Tests that `callDepth()` outputs the corrects call depth. | ||
/// @param _callDepth Call depth to test. | ||
function testFuzz_callDepth_succeeds(uint256 _callDepth) public { | ||
assembly ("memory-safe") { | ||
tstore(sload(callDepthSlot.slot), _callDepth) | ||
} | ||
assertEq(TransientContextBytes.callDepth(), _callDepth); | ||
} | ||
|
||
/// @notice Tests that `increment()` increments the call depth. | ||
/// @param _startingCallDepth Starting call depth. | ||
function testFuzz_increment_succeeds(uint256 _startingCallDepth) public { | ||
vm.assume(_startingCallDepth < type(uint256).max); | ||
assembly ("memory-safe") { | ||
tstore(sload(callDepthSlot.slot), _startingCallDepth) | ||
} | ||
assertEq(TransientContextBytes.callDepth(), _startingCallDepth); | ||
|
||
TransientContextBytes.increment(); | ||
assertEq(TransientContextBytes.callDepth(), _startingCallDepth + 1); | ||
} | ||
|
||
/// @notice Tests that `decrement()` decrements the call depth. | ||
/// @param _startingCallDepth Starting call depth. | ||
function testFuzz_decrement_succeeds(uint256 _startingCallDepth) public { | ||
vm.assume(_startingCallDepth > 0); | ||
assembly ("memory-safe") { | ||
tstore(sload(callDepthSlot.slot), _startingCallDepth) | ||
} | ||
assertEq(TransientContextBytes.callDepth(), _startingCallDepth); | ||
|
||
TransientContextBytes.decrement(); | ||
assertEq(TransientContextBytes.callDepth(), _startingCallDepth - 1); | ||
} | ||
|
||
/// @notice Tests that `get()` returns the correct value. | ||
/// @param _slot Slot to test. | ||
/// @param _value Value to test. | ||
function testFuzz_get_succeeds(bytes32 _slot, bytes calldata _value) public { | ||
bytes32 tslot = keccak256(abi.encodePacked(TransientContextBytes.callDepth(), _slot)); | ||
|
||
bytes memory emptyValue = TransientContextBytes.get(bytes32(0)); | ||
assertEq(TransientContextBytes.get(tslot), emptyValue); | ||
|
||
TransientContextBytes.set(tslot, _value); | ||
|
||
assertEq(TransientContextBytes.get(tslot), _value); | ||
} | ||
|
||
/// @notice Tests that `set()` sets the correct value. | ||
/// @param _slot Slot to test. | ||
/// @param _value Value to test. | ||
function testFuzz_set_succeeds(bytes32 _slot, bytes calldata _value) public { | ||
TransientContextBytes.set(_slot, _value); | ||
bytes32 tslot = keccak256(abi.encodePacked(TransientContextBytes.callDepth(), _slot)); | ||
bytes memory tvalue = TransientContextBytes.get(_slot); | ||
assertEq(tvalue, _value); | ||
} | ||
|
||
/// @notice Tests that `set()` and `get()` work together. | ||
/// @param _slot Slot to test. | ||
/// @param _value Value to test. | ||
function testFuzz_setGet_succeeds(bytes32 _slot, bytes calldata _value) public { | ||
testFuzz_set_succeeds(_slot, _value); | ||
assertEq(TransientContextBytes.get(_slot), _value); | ||
} | ||
|
||
/// @notice Tests that `set()` and `get()` work together at the same depth. | ||
/// @param _slot Slot to test. | ||
/// @param _value1 Value to write to slot at call depth 0. | ||
/// @param _value2 Value to write to slot at call depth 0. | ||
function testFuzz_setGet_twice_sameDepth_succeeds(bytes32 _slot, bytes calldata _value1, bytes calldata _value2) | ||
public | ||
{ | ||
assertEq(TransientContextBytes.callDepth(), 0); | ||
testFuzz_set_succeeds(_slot, _value1); | ||
assertEq(TransientContextBytes.get(_slot), _value1); | ||
|
||
assertEq(TransientContextBytes.callDepth(), 0); | ||
testFuzz_set_succeeds(_slot, _value2); | ||
assertEq(TransientContextBytes.get(_slot), _value2); | ||
} | ||
|
||
/// @notice Tests that `set()` and `get()` work together at different depths. | ||
/// @param _slot Slot to test. | ||
/// @param _value1 Value to write to slot at call depth 0. | ||
/// @param _value2 Value to write to slot at call depth 1. | ||
function testFuzz_setGet_twice_differentDepth_succeeds( | ||
bytes32 _slot, | ||
bytes calldata _value1, | ||
bytes calldata _value2 | ||
) public { | ||
assertEq(TransientContextBytes.callDepth(), 0); | ||
testFuzz_set_succeeds(_slot, _value1); | ||
assertEq(TransientContextBytes.get(_slot), _value1); | ||
|
||
TransientContextBytes.increment(); | ||
|
||
assertEq(TransientContextBytes.callDepth(), 1); | ||
testFuzz_set_succeeds(_slot, _value2); | ||
assertEq(TransientContextBytes.get(_slot), _value2); | ||
|
||
TransientContextBytes.decrement(); | ||
|
||
assertEq(TransientContextBytes.callDepth(), 0); | ||
assertEq(TransientContextBytes.get(_slot), _value1); | ||
} | ||
} |