Skip to content

Commit

Permalink
create support for dynamic byte arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
0xfuturistic committed Jul 9, 2024
1 parent 19cbd66 commit 72ab19c
Show file tree
Hide file tree
Showing 2 changed files with 259 additions and 0 deletions.
135 changes: 135 additions & 0 deletions src/TransientContextBytes.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.24;

/// @title TransientContextBytes
/// @notice Library for transient storage.
library TransientContextBytes {
error DataTooLarge();
error OutOfOrderSlots();
error RangeTooLarge();

/// @dev 4-bytes is way above current max contract size, meant to account for future EVM
/// versions.
uint256 internal constant LENGTH_MASK = 0xffffffff;
uint256 internal constant MAX_LENGTH = LENGTH_MASK;
uint256 internal constant LENGTH_BYTES = 4;

/// @notice Slot for call depth.
/// Equal to bytes32(uint256(keccak256("transient.calldepth")) - 1).
bytes32 internal constant CALL_DEPTH_SLOT = 0x7a74fd168763fd280eaec3bcd2fd62d0e795027adc8183a693c497a7c2b10b5c;

/// @notice Gets the call depth.
/// @return callDepth_ Current call depth.
function callDepth() internal view returns (uint256 callDepth_) {
assembly ("memory-safe") {
callDepth_ := tload(CALL_DEPTH_SLOT)
}
}

/// @notice Gets bytes value in transient storage for a slot at the current call depth.
/// @param _slot Slot to get.
/// @return _value Transient bytes value.
function get(bytes32 _slot) internal view returns (bytes memory _value) {
assembly ("memory-safe") {
// Allocate and load head.
_value := mload(0x40)
mstore(_value, 0)
mstore(0, tload(CALL_DEPTH_SLOT))
mstore(32, _slot)
let slot := keccak256(0, 64)
mstore(add(_value, sub(0x20, LENGTH_BYTES)), tload(slot))
// Get length and update free pointer.
let _valueStart := add(_value, 0x20)
let len := mload(_value)
mstore(0x40, add(_valueStart, len))

if gt(len, sub(0x20, LENGTH_BYTES)) {
// Derive extended slots.
mstore(0x00, slot)
slot := keccak256(0x00, 0x20)

// Store remainder.
let offset := add(_valueStart, sub(0x20, LENGTH_BYTES))
let endOffset := add(_valueStart, len)
for {} 1 {} {
mstore(offset, tload(slot))
offset := add(offset, 0x20)
if gt(offset, endOffset) { break }
slot := add(slot, 1)
}
mstore(endOffset, 0)
}
}
}

/// @notice Sets a bytes value in transient storage for a slot at the current call depth.
/// @param _slot Slot to set.
/// @param _value Value to set.
function set(bytes32 _slot, bytes memory _value) internal {
assembly ("memory-safe") {
let len := mload(_value)

if gt(len, LENGTH_MASK) {
mstore(0x00, 0x54ef47ee /* DataTooLarge() */ )
revert(0x1c, 0x04)
}

// Store first word packed with length
let _valueStart := add(_value, 0x20)
let head := mload(sub(_valueStart, LENGTH_BYTES))

mstore(0, tload(CALL_DEPTH_SLOT))
mstore(32, _slot)
let slot := keccak256(0, 64)

tstore(slot, head)

if gt(len, sub(0x20, LENGTH_BYTES)) {
// Derive extended slots.
mstore(0x00, slot)
slot := keccak256(0x00, 0x20)

// Store remainder.
let offset := add(_valueStart, sub(0x20, LENGTH_BYTES))
// Ensure each loop can do cheap comparison to see if it's at the end.
let endOffset := sub(add(_valueStart, len), 1)
for {} 1 {} {
tstore(slot, mload(offset))
offset := add(offset, 0x20)
if gt(offset, endOffset) { break }
slot := add(slot, 1)
}
}
}
}

/// @notice Increments call depth.
/// This function can overflow. However, this is ok because there's still
/// only one value stored per slot.
function increment() internal {
assembly ("memory-safe") {
tstore(CALL_DEPTH_SLOT, add(tload(CALL_DEPTH_SLOT), 1))
}
}

/// @notice Decrements call depth.
/// This function can underflow. However, this is ok because there's still
/// only one value stored per slot.
function decrement() internal {
assembly ("memory-safe") {
tstore(CALL_DEPTH_SLOT, sub(tload(CALL_DEPTH_SLOT), 1))
}
}
}

/// @title TransientReentrancyAware
/// @notice Reentrancy-aware modifier for transient storage, which increments and
/// decrements the call depth when entering and exiting a function.
contract TransientReentrancyAware {
/// @notice Modifier to make a function reentrancy-aware.
modifier reentrantAware() {
TransientContextBytes.increment();
_;
TransientContextBytes.decrement();
}
}
124 changes: 124 additions & 0 deletions test/TransientContextBytes.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.24;

// Testing utilities
import {Test} from "forge-std/Test.sol";

// Target contracts
import {TransientContextBytes} from "src/TransientContextBytes.sol";
import {TransientReentrancyAware} from "src/TransientContextBytes.sol";

/// @title TransientContextBytesTest
/// @notice Tests for the TransientContext library with bytes.
contract TransientContextBytesTest is Test {
/// @notice Slot for call depth.
bytes32 internal callDepthSlot = bytes32(uint256(keccak256("transient.calldepth")) - 1);

/// @notice Tests that `callDepth()` outputs the corrects call depth.
/// @param _callDepth Call depth to test.
function testFuzz_callDepth_succeeds(uint256 _callDepth) public {
assembly ("memory-safe") {
tstore(sload(callDepthSlot.slot), _callDepth)
}
assertEq(TransientContextBytes.callDepth(), _callDepth);
}

/// @notice Tests that `increment()` increments the call depth.
/// @param _startingCallDepth Starting call depth.
function testFuzz_increment_succeeds(uint256 _startingCallDepth) public {
vm.assume(_startingCallDepth < type(uint256).max);
assembly ("memory-safe") {
tstore(sload(callDepthSlot.slot), _startingCallDepth)
}
assertEq(TransientContextBytes.callDepth(), _startingCallDepth);

TransientContextBytes.increment();
assertEq(TransientContextBytes.callDepth(), _startingCallDepth + 1);
}

/// @notice Tests that `decrement()` decrements the call depth.
/// @param _startingCallDepth Starting call depth.
function testFuzz_decrement_succeeds(uint256 _startingCallDepth) public {
vm.assume(_startingCallDepth > 0);
assembly ("memory-safe") {
tstore(sload(callDepthSlot.slot), _startingCallDepth)
}
assertEq(TransientContextBytes.callDepth(), _startingCallDepth);

TransientContextBytes.decrement();
assertEq(TransientContextBytes.callDepth(), _startingCallDepth - 1);
}

/// @notice Tests that `get()` returns the correct value.
/// @param _slot Slot to test.
/// @param _value Value to test.
function testFuzz_get_succeeds(bytes32 _slot, bytes calldata _value) public {
bytes32 tslot = keccak256(abi.encodePacked(TransientContextBytes.callDepth(), _slot));

bytes memory emptyValue = TransientContextBytes.get(bytes32(0));
assertEq(TransientContextBytes.get(tslot), emptyValue);

TransientContextBytes.set(tslot, _value);

assertEq(TransientContextBytes.get(tslot), _value);
}

/// @notice Tests that `set()` sets the correct value.
/// @param _slot Slot to test.
/// @param _value Value to test.
function testFuzz_set_succeeds(bytes32 _slot, bytes calldata _value) public {
TransientContextBytes.set(_slot, _value);
bytes32 tslot = keccak256(abi.encodePacked(TransientContextBytes.callDepth(), _slot));
bytes memory tvalue = TransientContextBytes.get(_slot);
assertEq(tvalue, _value);
}

/// @notice Tests that `set()` and `get()` work together.
/// @param _slot Slot to test.
/// @param _value Value to test.
function testFuzz_setGet_succeeds(bytes32 _slot, bytes calldata _value) public {
testFuzz_set_succeeds(_slot, _value);
assertEq(TransientContextBytes.get(_slot), _value);
}

/// @notice Tests that `set()` and `get()` work together at the same depth.
/// @param _slot Slot to test.
/// @param _value1 Value to write to slot at call depth 0.
/// @param _value2 Value to write to slot at call depth 0.
function testFuzz_setGet_twice_sameDepth_succeeds(bytes32 _slot, bytes calldata _value1, bytes calldata _value2)
public
{
assertEq(TransientContextBytes.callDepth(), 0);
testFuzz_set_succeeds(_slot, _value1);
assertEq(TransientContextBytes.get(_slot), _value1);

assertEq(TransientContextBytes.callDepth(), 0);
testFuzz_set_succeeds(_slot, _value2);
assertEq(TransientContextBytes.get(_slot), _value2);
}

/// @notice Tests that `set()` and `get()` work together at different depths.
/// @param _slot Slot to test.
/// @param _value1 Value to write to slot at call depth 0.
/// @param _value2 Value to write to slot at call depth 1.
function testFuzz_setGet_twice_differentDepth_succeeds(
bytes32 _slot,
bytes calldata _value1,
bytes calldata _value2
) public {
assertEq(TransientContextBytes.callDepth(), 0);
testFuzz_set_succeeds(_slot, _value1);
assertEq(TransientContextBytes.get(_slot), _value1);

TransientContextBytes.increment();

assertEq(TransientContextBytes.callDepth(), 1);
testFuzz_set_succeeds(_slot, _value2);
assertEq(TransientContextBytes.get(_slot), _value2);

TransientContextBytes.decrement();

assertEq(TransientContextBytes.callDepth(), 0);
assertEq(TransientContextBytes.get(_slot), _value1);
}
}

0 comments on commit 72ab19c

Please sign in to comment.