diff --git a/src/TheCompact.sol b/src/TheCompact.sol index f630ea8..286fb8d 100644 --- a/src/TheCompact.sol +++ b/src/TheCompact.sol @@ -85,6 +85,7 @@ import { MetadataRenderer } from "./lib/MetadataRenderer.sol"; contract TheCompact is ITheCompact, ERC6909, Extsload { using HashLib for address; using HashLib for bytes32; + using HashLib for uint256; using HashLib for BasicTransfer; using HashLib for SplitTransfer; using HashLib for Claim; @@ -704,7 +705,7 @@ contract TheCompact is ITheCompact, ERC6909, Extsload { claimPayload.expires.later(); bytes32 exogenousDomainSeparator = - messageHash.toNotarizedDomainHash(claimPayload.notarizedChainId); + claimPayload.notarizedChainId.toNotarizedDomainSeparator(); messageHash.signedBy( claimPayload.sponsor, claimPayload.sponsorSignature, exogenousDomainSeparator ); @@ -951,6 +952,7 @@ contract TheCompact is ITheCompact, ERC6909, Extsload { function(address, address, uint256, uint256) internal returns (bool) operation ) internal returns (bool) { bytes32 messageHash = claimPayload.toMessageHash(); + _notExpiredAndWithValidSignatures.usingMultichainClaim()( messageHash, claimPayload, diff --git a/src/lib/ConsumerLib.sol b/src/lib/ConsumerLib.sol index 1b82346..748373e 100644 --- a/src/lib/ConsumerLib.sol +++ b/src/lib/ConsumerLib.sol @@ -24,7 +24,7 @@ library ConsumerLib { let bit := shl(and(0xff, nonce), 1) if and(bit, bucketValue) { // `InvalidNonce(address,uint256)` with padding for `account`. - mstore(0x0c, 0x8baa579f000000000000000000000000) + mstore(0x0c, 0xdbc205b1000000000000000000000000) revert(0x1c, 0x44) } diff --git a/src/lib/HashLib.sol b/src/lib/HashLib.sol index 7c82144..83435e9 100644 --- a/src/lib/HashLib.sol +++ b/src/lib/HashLib.sol @@ -576,10 +576,13 @@ library HashLib { let additionalChainsPtr := add(claim, calldataload(add(claim, 0xa0))) let additionalChainsLength := shl(5, calldataload(additionalChainsPtr)) calldatacopy(add(m, 0x20), add(0x20, additionalChainsPtr), additionalChainsLength) - mstore(0x80, keccak256(m, add(0x20, additionalChainsLength))) + + // hash of allocation hashes + mstore(add(m, 0x80), keccak256(m, add(0x20, additionalChainsLength))) mstore(m, MULTICHAIN_COMPACT_TYPEHASH) calldatacopy(add(m, 0x20), add(claim, 0x40), 0x60) // sponsor, nonce, expires + messageHash := keccak256(m, 0xa0) } } @@ -589,6 +592,8 @@ library HashLib { view returns (bytes32 messageHash) { + bytes32 allocationHash; + assembly ("memory-safe") { let m := mload(0x40) // Grab the free memory pointer; memory will be left dirtied. @@ -599,35 +604,35 @@ library HashLib { mstore(0, ALLOCATION_TYPEHASH) mstore(0x20, caller()) // arbiter mstore(0x40, chainid()) - let allocationHash := keccak256(0, 0x80) // allocation hash + allocationHash := keccak256(0, 0x80) // allocation hash mstore(0x40, m) mstore(0x60, 0) + } - let chainIndex := shl(5, calldataload(add(claim, 0xc0))) - - // all allocation hashes - let additionalChainsPtr := add(claim, calldataload(add(claim, 0xa0))) - let allChainsLength := add(0x20, shl(5, calldataload(additionalChainsPtr))) - - // TODO: likely a better way to do this; figure it out - let reductionOnceLocated := 0 - for { let i := 0 } lt(i, allChainsLength) { i := add(i, 0x20) } { - mstore( - add(m, i), - calldataload(add(sub(i, reductionOnceLocated), add(additionalChainsPtr, 0x20))) - ) - if eq(i, chainIndex) { - reductionOnceLocated := 0x20 - i := add(i, 0x20) - mstore(add(m, i), allocationHash) - } + // TODO: use inline assembly for this + uint256 additionalChainsLength = claim.additionalChains.length; + bytes32[] memory allocationHashes = new bytes32[](additionalChainsLength + 1); + uint256 extraOffset = 0; + for (uint256 i = 0; i < additionalChainsLength; ++i) { + allocationHashes[i] = claim.additionalChains[i + extraOffset]; + if (i == claim.chainIndex) { + extraOffset = 1; + allocationHashes[i + 1] = allocationHash; } + } + + bytes32 allocationHashesHash = keccak256(abi.encodePacked(allocationHashes)); + + assembly ("memory-safe") { + let m := mload(0x40) // Grab the free memory pointer; memory will be left dirtied. - mstore(0x80, keccak256(m, allChainsLength)) + // hash of allocation hashes + mstore(add(m, 0x80), allocationHashesHash) mstore(m, MULTICHAIN_COMPACT_TYPEHASH) calldatacopy(add(m, 0x20), add(claim, 0x40), 0x60) // sponsor, nonce, expires + messageHash := keccak256(m, 0xa0) } } @@ -743,30 +748,19 @@ library HashLib { } } - function toNotarizedDomainHash(bytes32 messageHash, uint256 notarizedChainId) + function toNotarizedDomainSeparator(uint256 notarizedChainId) internal view - returns (bytes32 domainHash) + returns (bytes32 notarizedDomainSeparator) { assembly ("memory-safe") { let m := mload(0x40) // Grab the free memory pointer. - - // Prepare the 712 prefix. - mstore(0, 0x1901) - - // Prepare the domain separator. mstore(m, _DOMAIN_TYPEHASH) mstore(add(m, 0x20), _NAME_HASH) mstore(add(m, 0x40), _VERSION_HASH) mstore(add(m, 0x60), notarizedChainId) mstore(add(m, 0x80), address()) - mstore(0x20, keccak256(m, 0xa0)) - - // Prepare the message hash and compute the domain hash. - mstore(0x40, messageHash) - domainHash := keccak256(0x1e, 0x42) - - mstore(0x40, m) // Restore the free memory pointer. + notarizedDomainSeparator := keccak256(m, 0xa0) } } diff --git a/test/TheCompact.t.sol b/test/TheCompact.t.sol index 2c65318..73ecdf3 100644 --- a/test/TheCompact.t.sol +++ b/test/TheCompact.t.sol @@ -9,6 +9,8 @@ import { ResetPeriod } from "../src/types/ResetPeriod.sol"; import { Scope } from "../src/types/Scope.sol"; import { ISignatureTransfer } from "permit2/src/interfaces/ISignatureTransfer.sol"; +import { HashLib } from "../src/lib/HashLib.sol"; + import { BasicTransfer, SplitTransfer, @@ -34,6 +36,8 @@ import { QualifiedSplitBatchClaimWithWitness } from "../src/types/BatchClaims.sol"; +import { MultichainClaim, ExogenousMultichainClaim } from "../src/types/MultichainClaims.sol"; + import { SplitComponent, TransferComponent, @@ -2737,4 +2741,156 @@ contract TheCompactTest is Test { assertEq(theCompact.balanceOf(recipientOne, anotherId), anotherAmount); assertEq(theCompact.balanceOf(recipientTwo, aThirdId), aThirdAmount); } + + function test_multichainClaim() public { + console.log("chainID", block.chainid); + + ResetPeriod resetPeriod = ResetPeriod.TenMinutes; + Scope scope = Scope.Multichain; + uint256 amount = 1e18; + uint256 anotherAmount = 1e18; + uint256 nonce = 0; + uint256 expires = block.timestamp + 1000; + address claimant = 0x1111111111111111111111111111111111111111; + address arbiter = 0x2222222222222222222222222222222222222222; + uint256 anotherChainId = 7171717; + + vm.prank(allocator); + theCompact.__register(allocator, ""); + + vm.startPrank(swapper); + uint256 id = theCompact.deposit{ value: amount }(allocator, resetPeriod, scope, swapper); + uint256 anotherId = theCompact.deposit( + address(token), + allocator, + ResetPeriod.TenMinutes, + Scope.Multichain, + anotherAmount, + swapper + ); + vm.stopPrank(); + + assertEq(theCompact.balanceOf(swapper, id), amount); + assertEq(theCompact.balanceOf(swapper, anotherId), anotherAmount); + + uint256[2][] memory idsAndAmountsOne = new uint256[2][](1); + idsAndAmountsOne[0] = [id, amount]; + + uint256[2][] memory idsAndAmountsTwo = new uint256[2][](1); + idsAndAmountsTwo[0] = [anotherId, anotherAmount]; + + bytes32 allocationHashOne = keccak256( + abi.encode( + keccak256("Allocation(address arbiter,uint256 chainId,uint256[2][] idsAndAmounts)"), + arbiter, + block.chainid, + keccak256(abi.encodePacked(idsAndAmountsOne)) + ) + ); + + bytes32 allocationHashTwo = keccak256( + abi.encode( + keccak256("Allocation(address arbiter,uint256 chainId,uint256[2][] idsAndAmounts)"), + arbiter, + anotherChainId, + keccak256(abi.encodePacked(idsAndAmountsTwo)) + ) + ); + + bytes32 claimHash = keccak256( + abi.encode( + keccak256( + "MultichainCompact(address sponsor,uint256 nonce,uint256 expires,Allocation[] allocations)Allocation(address arbiter,uint256 chainId,uint256[2][] idsAndAmounts)" + ), + swapper, + nonce, + expires, + keccak256(abi.encodePacked(allocationHashOne, allocationHashTwo)) + ) + ); + + bytes32 initialDomainSeparator = theCompact.DOMAIN_SEPARATOR(); + + bytes32 digest = + keccak256(abi.encodePacked(bytes2(0x1901), initialDomainSeparator, claimHash)); + + (bytes32 r, bytes32 vs) = vm.signCompact(swapperPrivateKey, digest); + bytes memory sponsorSignature = abi.encodePacked(r, vs); + + (r, vs) = vm.signCompact(allocatorPrivateKey, digest); + bytes memory allocatorSignature = abi.encodePacked(r, vs); + + bytes32[] memory additionalChains = new bytes32[](1); + additionalChains[0] = allocationHashTwo; + + MultichainClaim memory claim = MultichainClaim( + allocatorSignature, + sponsorSignature, + swapper, + nonce, + expires, + additionalChains, + id, + amount, + claimant, + amount + ); + + uint256 snapshotId = vm.snapshot(); + vm.prank(arbiter); + (bool status) = theCompact.claim(claim); + assert(status); + + assertEq(address(theCompact).balance, amount); + assertEq(claimant.balance, 0); + assertEq(theCompact.balanceOf(swapper, id), 0); + assertEq(theCompact.balanceOf(claimant, id), amount); + vm.revertToAndDelete(snapshotId); + + // change to "new chain" (this hack is so the original one gets stored) + uint256 notarizedChainId = abi.decode(abi.encode(block.chainid), (uint256)); + assert(notarizedChainId != anotherChainId); + vm.chainId(anotherChainId); + assertEq(block.chainid, anotherChainId); + assert(notarizedChainId != anotherChainId); + + bytes32 anotherDomainSeparator = theCompact.DOMAIN_SEPARATOR(); + + assert(initialDomainSeparator != anotherDomainSeparator); + + digest = keccak256(abi.encodePacked(bytes2(0x1901), anotherDomainSeparator, claimHash)); + + (r, vs) = vm.signCompact(allocatorPrivateKey, digest); + bytes memory exogenousAllocatorSignature = abi.encodePacked(r, vs); + + additionalChains[0] = allocationHashOne; + uint256 chainIndex = 0; + + ExogenousMultichainClaim memory anotherClaim = ExogenousMultichainClaim( + exogenousAllocatorSignature, + sponsorSignature, + swapper, + nonce, + expires, + additionalChains, + chainIndex, + notarizedChainId, + anotherId, + anotherAmount, + claimant, + anotherAmount + ); + + vm.prank(arbiter); + (bool exogenousStatus) = theCompact.claim(anotherClaim); + assert(exogenousStatus); + + assertEq(theCompact.balanceOf(swapper, anotherId), 0); + assertEq(theCompact.balanceOf(claimant, anotherId), anotherAmount); + + // change back + vm.chainId(notarizedChainId); + assertEq(block.chainid, notarizedChainId); + console.log("chainID", block.chainid); + } }