diff --git a/README.md b/README.md index 5844ad4..be2b8e8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# Cross-Chain Swap Optimization Hook +# Cross-Chain Swap Optimization Hook for Uniswap V4 -This is the implementation of a cross-chain swap optimization system that finds the best execution venue across multiple blockchains. +This is a Uniswap V4 hook implementation that optimizes swaps across multiple blockchains, automatically routing users to the most profitable execution venue. ## Project Structure @@ -19,7 +19,12 @@ src/ ### CrossChainSwapHook.sol -The main contract that handles cross-chain swap optimization. It inherits from OpenZeppelin's `Ownable`, `ReentrancyGuard`, and `Pausable`. +The main Uniswap V4 hook contract that handles cross-chain swap optimization. It inherits from Uniswap V4's `BaseHook` and OpenZeppelin's `Ownable`, `ReentrancyGuard`, and `Pausable`. + +**Hook Implementation:** +- Implements `beforeSwap` and `afterSwap` hooks from Uniswap V4 +- Automatically analyzes cross-chain opportunities before each swap +- Routes users to the most profitable execution venue **Key State Variables:** - `pythOracle`: IPythOracle instance for price feeds @@ -28,19 +33,21 @@ The main contract that handles cross-chain swap optimization. It inherits from O - `tokenPriceData[]`: Mapping of tokens to their price feed configurations - `maxSlippageBps`: Maximum allowed slippage (default: 300 = 3%) - `protocolFeeBps`: Protocol fee in basis points (default: 10 = 0.1%) +- `crossChainThresholdBps`: Minimum improvement threshold for cross-chain (default: 200 = 2%) + +**Hook Functions:** +- `_beforeSwap()`: Analyzes cross-chain opportunities and executes if profitable +- `_afterSwap()`: Handles post-swap logic and emits events +- `getHookPermissions()`: Returns hook permissions for V4 validation -**Main Functions:** -- `executeSwap(SwapRequest memory request)`: Executes the swap optimization -- `simulateSwap(SwapRequest memory request)`: Returns quotes without executing -- `addVenue(uint256 chainId, address venueAddress, string memory name, uint256 gasEstimate)`: Adds new execution venue -- `configurePriceData(address token, bytes32 priceId, uint256 maxStaleness)`: Configures token price feeds +**Admin Functions:** +- `addVenue()`: Adds new execution venue +- `configurePriceData()`: Configures token price feeds +- `simulateSwap()`: Returns quotes without executing **Structs:** ```solidity -struct SwapRequest { - address tokenIn; - address tokenOut; - uint256 amountIn; +struct CrossChainSwapData { uint256 minAmountOut; address recipient; uint256 deadline; @@ -48,6 +55,7 @@ struct SwapRequest { bytes32 tokenOutPriceId; uint256 maxGasPrice; bool forceLocal; + uint256 thresholdBps; // Minimum improvement threshold in basis points } struct SwapVenue { @@ -153,8 +161,9 @@ struct ScoringWeights { ## Usage Example ```solidity -// Deploy the contract +// Deploy the hook contract CrossChainSwapHook hook = new CrossChainSwapHook( + poolManagerAddress, pythOracleAddress, bridgeProtocolAddress, feeRecipientAddress @@ -168,21 +177,19 @@ hook.addVenue(42161, arbitrumVenueAddress, "Arbitrum Uniswap", 150000); hook.configurePriceData(WETH, ETH_PRICE_ID, 600); hook.configurePriceData(USDC, USDC_PRICE_ID, 300); -// Execute swap -SwapRequest memory request = SwapRequest({ - tokenIn: WETH, - tokenOut: USDC, - amountIn: 1e18, - minAmountOut: 1900e6, - recipient: msg.sender, - deadline: block.timestamp + 3600, - tokenInPriceId: ETH_PRICE_ID, - tokenOutPriceId: USDC_PRICE_ID, - maxGasPrice: 50 gwei, - forceLocal: false +// Create a Uniswap V4 pool with the hook +PoolKey memory poolKey = PoolKey({ + currency0: Currency.wrap(WETH), + currency1: Currency.wrap(USDC), + fee: 3000, + tickSpacing: 60, + hooks: IHooks(address(hook)) }); -bytes32 swapId = hook.executeSwap(request); +poolManager.initialize(poolKey, sqrtPriceX96); + +// Users can now swap through the pool, and the hook will automatically +// analyze cross-chain opportunities and route to the best venue ``` ## Events @@ -242,8 +249,15 @@ The project includes basic tests in `test/Counter.t.sol`. Additional tests shoul ### Deployment ```bash -# Deploy to testnet -forge script script/Counter.s.sol --rpc-url --broadcast +# Set environment variables +export PRIVATE_KEY="your_private_key" +export POOL_MANAGER="0x..." +export PYTH_ORACLE="0x..." +export BRIDGE_PROTOCOL="0x..." +export FEE_RECIPIENT="0x..." + +# Deploy the hook +forge script script/DeployCrossChainHook.s.sol --rpc-url --broadcast # Verify on block explorer forge verify-contract src/CrossChainSwapHook.sol:CrossChainSwapHook --chain-id @@ -251,6 +265,8 @@ forge verify-contract src/CrossChainSwapHook.sol:CrossChainSw ## Dependencies +- Uniswap V4 Core (`v4-core`) +- Uniswap V4 Periphery (`v4-periphery`) - OpenZeppelin Contracts v5.4.0 - Foundry (forge, cast, anvil) - Solidity ^0.8.24 diff --git a/foundry.toml b/foundry.toml index 60e359c..fe34ea6 100644 --- a/foundry.toml +++ b/foundry.toml @@ -2,6 +2,9 @@ src = "src" out = "out" libs = ["lib"] +via_ir = true +optimizer = true +optimizer_runs = 200 remappings = [ "@openzeppelin/contracts/=lib/openzeppelin-contracts/contracts/", "v4-core/=lib/v4-core/src/", diff --git a/script/DeployCrossChainHook.s.sol b/script/DeployCrossChainHook.s.sol new file mode 100644 index 0000000..bc09f2f --- /dev/null +++ b/script/DeployCrossChainHook.s.sol @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {Script, console} from "forge-std/Script.sol"; +import {CrossChainSwapHook} from "../src/CrossChainSwapHook.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract DeployCrossChainHook is Script { + function run() external { + uint256 deployerPrivateKey = vm.envUint("PRIVATE_KEY"); + address poolManager = vm.envAddress("POOL_MANAGER"); + address pythOracle = vm.envAddress("PYTH_ORACLE"); + address bridgeProtocol = vm.envAddress("BRIDGE_PROTOCOL"); + address feeRecipient = vm.envAddress("FEE_RECIPIENT"); + + vm.startBroadcast(deployerPrivateKey); + + CrossChainSwapHook hook = new CrossChainSwapHook( + IPoolManager(poolManager), + pythOracle, + bridgeProtocol, + feeRecipient + ); + + console.log("CrossChainSwapHook deployed at:", address(hook)); + console.log("Hook permissions:", _getHookPermissionsString(hook.getHookPermissions())); + + vm.stopBroadcast(); + } + + function _getHookPermissionsString(Hooks.Permissions memory permissions) + internal + pure + returns (string memory) + { + return string(abi.encodePacked( + "beforeSwap: ", permissions.beforeSwap ? "true" : "false", ", ", + "afterSwap: ", permissions.afterSwap ? "true" : "false" + )); + } +} diff --git a/src/CrossChainSwapHook.sol b/src/CrossChainSwapHook.sol index 593d82b..6bf2a5d 100644 --- a/src/CrossChainSwapHook.sol +++ b/src/CrossChainSwapHook.sol @@ -1,9 +1,17 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; +import {BaseHook} from "v4-periphery/utils/BaseHook.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {ModifyLiquidityParams, SwapParams} from "@uniswap/v4-core/src/types/PoolOperation.sol"; + import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; -import {Ownable2Step} from "@openzeppelin/contracts/access/Ownable2Step.sol"; import {ReentrancyGuard} from "@openzeppelin/contracts/utils/ReentrancyGuard.sol"; import {Pausable} from "@openzeppelin/contracts/utils/Pausable.sol"; @@ -12,16 +20,16 @@ import "./interfaces/IBridgeProtocol.sol"; import "./libraries/PriceCalculator.sol"; import "./libraries/VenueComparator.sol"; -// Cross-chain swap optimization contract -contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { + +contract CrossChainSwapHook is BaseHook, Ownable, ReentrancyGuard, Pausable { using PriceCalculator for IPythOracle.Price; using VenueComparator for VenueComparator.ComparisonData; // Events event CrossChainSwapExecuted( address indexed user, - address indexed tokenIn, - address indexed tokenOut, + Currency indexed tokenIn, + Currency indexed tokenOut, uint256 amountIn, uint256 amountOut, uint256 destinationChainId, @@ -32,8 +40,8 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { event LocalSwapOptimized( address indexed user, - address indexed tokenIn, - address indexed tokenOut, + Currency indexed tokenIn, + Currency indexed tokenOut, uint256 amountIn, uint256 expectedOut, bytes32 swapId @@ -65,6 +73,7 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { error ExcessiveGasCost(); error PriceDataStale(); error UnauthorizedCaller(); + error CrossChainNotProfitable(); // Structs struct SwapVenue { @@ -88,10 +97,7 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { uint8 confidenceScore; } - struct SwapRequest { - address tokenIn; - address tokenOut; - uint256 amountIn; + struct CrossChainSwapData { uint256 minAmountOut; address recipient; uint256 deadline; @@ -99,6 +105,7 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { bytes32 tokenOutPriceId; uint256 maxGasPrice; bool forceLocal; + uint256 thresholdBps; // Minimum improvement threshold in basis points } struct PriceData { @@ -125,6 +132,7 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { uint256 public defaultPriceStaleness = 300; // 5 minutes address public feeRecipient; uint256 public protocolFeeBps = 10; // 0.1% + uint256 public crossChainThresholdBps = 200; // 2% minimum improvement for cross-chain // Modifiers modifier validVenue(uint256 venueIndex) { @@ -140,10 +148,11 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { // Constructor constructor( + IPoolManager _poolManager, address _pythOracle, address _bridgeProtocol, address _feeRecipient - ) Ownable(_feeRecipient) { + ) BaseHook(_poolManager) Ownable(_feeRecipient) { if (_pythOracle == address(0)) revert ZeroAddress(); if (_bridgeProtocol == address(0)) revert ZeroAddress(); if (_feeRecipient == address(0)) revert ZeroAddress(); @@ -164,47 +173,82 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { supportedChains[CURRENT_CHAIN_ID] = true; } - // Execution logic - function executeSwap(SwapRequest memory request) - external - nonReentrant - whenNotPaused - notExpired(request.deadline) - returns (bytes32 swapId) - { - _validateSwapRequest(request); - - swapId = keccak256(abi.encodePacked( - msg.sender, - block.timestamp, - request.tokenIn, - request.tokenOut, - request.amountIn - )); - - ExecutionQuote memory bestQuote = _getBestExecutionVenue(request); - - if (bestQuote.venueIndex == 0 || request.forceLocal) { - emit LocalSwapOptimized( - msg.sender, - request.tokenIn, - request.tokenOut, - request.amountIn, - bestQuote.outputAmount, - swapId - ); - return swapId; + /// @notice Returns hook permissions for Uniswap V4 + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: false, + afterInitialize: false, + beforeAddLiquidity: false, + afterAddLiquidity: false, + beforeRemoveLiquidity: false, + afterRemoveLiquidity: false, + beforeSwap: true, + afterSwap: true, + beforeDonate: false, + afterDonate: false, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + /// @notice Hook called before a swap to analyze cross-chain opportunities + function _beforeSwap( + address sender, + PoolKey calldata key, + SwapParams calldata params, + bytes calldata hookData + ) internal override returns (bytes4, BeforeSwapDelta, uint24) { + // Decode hook data to get cross-chain swap parameters + CrossChainSwapData memory swapData = abi.decode(hookData, (CrossChainSwapData)); + + // Validate swap data + _validateSwapData(swapData); + + // Analyze cross-chain opportunities + ExecutionQuote memory bestQuote = _getBestExecutionVenue(key, params, swapData); + + // If cross-chain is more profitable, execute cross-chain swap + if (bestQuote.venueIndex != 0 && !swapData.forceLocal) { + if (_isCrossChainProfitable(bestQuote, params, swapData)) { + _executeCrossChainSwap(sender, key, params, bestQuote, swapData); + // Return early to prevent local swap + return (BaseHook.beforeSwap.selector, BeforeSwapDelta.wrap(0), 0); + } } - _executeCrossChainSwap(msg.sender, request, bestQuote, swapId); - return swapId; + // Continue with local swap + return (BaseHook.beforeSwap.selector, BeforeSwapDelta.wrap(0), 0); } - function _getBestExecutionVenue(SwapRequest memory request) - internal - view - returns (ExecutionQuote memory bestQuote) - { + /// @notice Hook called after a swap to handle any post-swap logic + function _afterSwap( + address sender, + PoolKey calldata key, + SwapParams calldata params, + BalanceDelta delta, + bytes calldata hookData + ) internal override returns (bytes4, int128) { + // Emit event for local swap completion + emit LocalSwapOptimized( + sender, + key.currency0, + key.currency1, + uint256(int256(params.amountSpecified)), + uint256(int256(delta.amount0())), + keccak256(abi.encodePacked(sender, block.timestamp, key.currency0, key.currency1)) + ); + + return (BaseHook.afterSwap.selector, 0); + } + + /// @notice Get the best execution venue for a swap + function _getBestExecutionVenue( + PoolKey calldata key, + SwapParams calldata params, + CrossChainSwapData memory swapData + ) internal view returns (ExecutionQuote memory bestQuote) { bestQuote.netOutput = 0; bestQuote.confidenceScore = 0; @@ -213,7 +257,7 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { for (uint256 i = 0; i < venueCount; i++) { if (!venues[i].isActive) continue; - quotes[i] = _getVenueQuote(request, venues[i], i); + quotes[i] = _getVenueQuote(key, params, venues[i], i, swapData); if (_isBetterQuote(quotes[i], bestQuote)) { bestQuote = quotes[i]; @@ -223,22 +267,25 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { return bestQuote; } + /// @notice Get quote for a specific venue function _getVenueQuote( - SwapRequest memory request, + PoolKey calldata key, + SwapParams calldata params, SwapVenue memory venue, - uint256 venueIndex + uint256 venueIndex, + CrossChainSwapData memory swapData ) internal view returns (ExecutionQuote memory quote) { quote.venueIndex = venueIndex; quote.requiresBridge = venue.chainId != CURRENT_CHAIN_ID; quote.executionTime = venue.chainId == CURRENT_CHAIN_ID ? 15 : 300; - (quote.outputAmount, quote.confidenceScore) = _calculateOutputAmount(request); + (quote.outputAmount, quote.confidenceScore) = _calculateOutputAmount(key, params, swapData); if (quote.outputAmount == 0) { return quote; } - quote.totalCost = _calculateExecutionCost(request, venue, quote.requiresBridge); + quote.totalCost = _calculateExecutionCost(key, params, venue, quote.requiresBridge); quote.netOutput = quote.outputAmount > quote.totalCost ? quote.outputAmount - quote.totalCost @@ -246,9 +293,9 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { if (quote.requiresBridge && quote.netOutput > 0) { try bridgeProtocol.getQuote( - request.tokenIn, - request.tokenOut, - request.amountIn, + Currency.unwrap(key.currency0), + Currency.unwrap(key.currency1), + uint256(int256(params.amountSpecified)), venue.chainId ) returns (IBridgeProtocol.BridgeQuote memory bridgeQuote) { quote.bridgeData = bridgeQuote.bridgeData; @@ -266,13 +313,14 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { return quote; } - function _calculateOutputAmount(SwapRequest memory request) - internal - view - returns (uint256 outputAmount, uint8 confidenceScore) - { - PriceData memory priceDataIn = tokenPriceData[request.tokenIn]; - PriceData memory priceDataOut = tokenPriceData[request.tokenOut]; + /// @notice Calculate output amount using Pyth oracle + function _calculateOutputAmount( + PoolKey calldata key, + SwapParams calldata params, + CrossChainSwapData memory swapData + ) internal view returns (uint256 outputAmount, uint8 confidenceScore) { + PriceData memory priceDataIn = tokenPriceData[Currency.unwrap(key.currency0)]; + PriceData memory priceDataOut = tokenPriceData[Currency.unwrap(key.currency1)]; if (!priceDataIn.isActive || !priceDataOut.isActive) { return (0, 0); @@ -286,7 +334,7 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { return (0, 0); } - outputAmount = priceIn.calculateOutputAmount(priceOut, request.amountIn); + outputAmount = priceIn.calculateOutputAmount(priceOut, uint256(int256(params.amountSpecified))); outputAmount = outputAmount * (10000 - maxSlippageBps) / 10000; confidenceScore = _calculateConfidenceScore(priceIn, priceOut); @@ -300,109 +348,132 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { return (outputAmount, confidenceScore); } + /// @notice Calculate execution cost for a venue function _calculateExecutionCost( - SwapRequest memory request, + PoolKey calldata key, + SwapParams calldata params, SwapVenue memory venue, bool requiresBridge ) internal view returns (uint256 totalCost) { - uint256 gasPrice = request.maxGasPrice > 0 ? request.maxGasPrice : tx.gasprice; + uint256 gasPrice = tx.gasprice; uint256 gasCost = venue.baseGasEstimate * gasPrice; - uint256 maxAllowedGasCost = request.amountIn * maxGasCostBps / 10000; + uint256 maxAllowedGasCost = uint256(int256(params.amountSpecified)) * maxGasCostBps / 10000; if (gasCost > maxAllowedGasCost) { gasCost = maxAllowedGasCost; } totalCost = gasCost; - totalCost += request.amountIn * protocolFeeBps / 10000; + totalCost += uint256(int256(params.amountSpecified)) * protocolFeeBps / 10000; return totalCost; } - function _isBetterQuote(ExecutionQuote memory newQuote, ExecutionQuote memory currentBest) - internal - pure - returns (bool) - { - if (currentBest.netOutput == 0) return newQuote.netOutput > 0; - if (newQuote.netOutput == 0) return false; - - uint256 newScore = newQuote.netOutput * newQuote.confidenceScore; - uint256 currentScore = currentBest.netOutput * currentBest.confidenceScore; + /// @notice Check if cross-chain swap is profitable + function _isCrossChainProfitable( + ExecutionQuote memory quote, + SwapParams calldata params, + CrossChainSwapData memory swapData + ) internal pure returns (bool) { + if (quote.netOutput == 0) return false; - return newScore > currentScore; - } - - function _calculateConfidenceScore( - IPythOracle.Price memory priceIn, - IPythOracle.Price memory priceOut - ) internal pure returns (uint8) { - uint256 priceInConf = uint256(priceIn.conf) * 10000 / uint256(int256(priceIn.price)); - uint256 priceOutConf = uint256(priceOut.conf) * 10000 / uint256(int256(priceOut.price)); + // Calculate local swap output (simplified - in reality would use pool price) + uint256 localOutput = uint256(int256(params.amountSpecified)) * 95 / 100; // Assume 5% slippage - uint256 avgConfidence = (priceInConf + priceOutConf) / 2; + // Check if cross-chain provides sufficient improvement + uint256 improvement = quote.netOutput > localOutput ? quote.netOutput - localOutput : 0; + uint256 improvementBps = localOutput > 0 ? improvement * 10000 / localOutput : 0; - if (avgConfidence > 500) return 20; - if (avgConfidence > 200) return 50; - if (avgConfidence > 100) return 70; - if (avgConfidence > 50) return 85; - return 95; + return improvementBps >= swapData.thresholdBps; } + /// @notice Execute cross-chain swap function _executeCrossChainSwap( address sender, - SwapRequest memory request, + PoolKey calldata key, + SwapParams calldata params, ExecutionQuote memory quote, - bytes32 swapId + CrossChainSwapData memory swapData ) internal { - if (quote.netOutput < request.minAmountOut) revert InsufficientOutputAmount(); + if (quote.netOutput < swapData.minAmountOut) revert InsufficientOutputAmount(); SwapVenue memory venue = venues[quote.venueIndex]; - IERC20(request.tokenIn).transferFrom(sender, address(this), request.amountIn); + // Transfer tokens from sender to this contract + IERC20(Currency.unwrap(key.currency0)).transferFrom(sender, address(this), uint256(int256(params.amountSpecified))); - uint256 protocolFee = request.amountIn * protocolFeeBps / 10000; + uint256 protocolFee = uint256(int256(params.amountSpecified)) * protocolFeeBps / 10000; if (protocolFee > 0) { - IERC20(request.tokenIn).transfer(feeRecipient, protocolFee); + IERC20(Currency.unwrap(key.currency0)).transfer(feeRecipient, protocolFee); } - uint256 bridgeAmount = request.amountIn - protocolFee; - IERC20(request.tokenIn).approve(address(bridgeProtocol), bridgeAmount); + uint256 bridgeAmount = uint256(int256(params.amountSpecified)) - protocolFee; + IERC20(Currency.unwrap(key.currency0)).approve(address(bridgeProtocol), bridgeAmount); try bridgeProtocol.bridge{value: msg.value}( - request.tokenIn, + Currency.unwrap(key.currency0), bridgeAmount, venue.chainId, - request.recipient, + swapData.recipient, quote.bridgeData ) { emit CrossChainSwapExecuted( sender, - request.tokenIn, - request.tokenOut, - request.amountIn, + key.currency0, + key.currency1, + uint256(int256(params.amountSpecified)), quote.outputAmount, venue.chainId, venue.venueAddress, quote.totalCost, - swapId + keccak256(abi.encodePacked(sender, block.timestamp, key.currency0, key.currency1)) ); } catch { - IERC20(request.tokenIn).transfer(sender, bridgeAmount); + IERC20(Currency.unwrap(key.currency0)).transfer(sender, bridgeAmount); if (protocolFee > 0) { - IERC20(request.tokenIn).transferFrom(feeRecipient, sender, protocolFee); + IERC20(Currency.unwrap(key.currency0)).transferFrom(feeRecipient, sender, protocolFee); } revert("Bridge execution failed"); } } - function _validateSwapRequest(SwapRequest memory request) internal view { - if (request.tokenIn == address(0) || request.tokenOut == address(0)) revert ZeroAddress(); - if (request.recipient == address(0)) revert ZeroAddress(); - if (request.amountIn == 0) revert("Invalid amount"); - if (request.deadline <= block.timestamp) revert SwapExpired(); - if (!tokenPriceData[request.tokenIn].isActive) revert TokenNotSupported(); - if (!tokenPriceData[request.tokenOut].isActive) revert TokenNotSupported(); + /// @notice Validate swap data + function _validateSwapData(CrossChainSwapData memory swapData) internal view { + if (swapData.recipient == address(0)) revert ZeroAddress(); + if (swapData.deadline <= block.timestamp) revert SwapExpired(); + if (swapData.thresholdBps > 1000) revert InvalidThresholdParameters(); + } + + /// @notice Check if new quote is better than current best + function _isBetterQuote(ExecutionQuote memory newQuote, ExecutionQuote memory currentBest) + internal + pure + returns (bool) + { + if (currentBest.netOutput == 0) return newQuote.netOutput > 0; + if (newQuote.netOutput == 0) return false; + + uint256 newScore = newQuote.netOutput * newQuote.confidenceScore; + uint256 currentScore = currentBest.netOutput * currentBest.confidenceScore; + + return newScore > currentScore; + } + + /// @notice Calculate confidence score from price data + function _calculateConfidenceScore( + IPythOracle.Price memory priceIn, + IPythOracle.Price memory priceOut + ) internal pure returns (uint8) { + uint256 priceInConf = uint256(priceIn.conf) * 10000 / uint256(int256(priceIn.price)); + uint256 priceOutConf = uint256(priceOut.conf) * 10000 / uint256(int256(priceOut.price)); + + uint256 avgConfidence = (priceInConf + priceOutConf) / 2; + + if (avgConfidence > 500) return 20; + if (avgConfidence > 200) return 50; + if (avgConfidence > 100) return 70; + if (avgConfidence > 50) return 85; + return 95; } // Admin functions @@ -493,13 +564,16 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { function updateSwapParameters( uint256 _maxSlippageBps, uint256 _bridgeSlippageBps, - uint256 _minBridgeAmount + uint256 _minBridgeAmount, + uint256 _crossChainThresholdBps ) external onlyOwner { if (_maxSlippageBps > 1000 || _bridgeSlippageBps > 500) revert InvalidSlippageParameters(); + if (_crossChainThresholdBps > 1000) revert InvalidThresholdParameters(); maxSlippageBps = _maxSlippageBps; bridgeSlippageBps = _bridgeSlippageBps; minBridgeAmount = _minBridgeAmount; + crossChainThresholdBps = _crossChainThresholdBps; emit SwapParametersUpdated(_maxSlippageBps, _bridgeSlippageBps, _minBridgeAmount); } @@ -560,17 +634,17 @@ contract CrossChainSwapHook is Ownable2Step, ReentrancyGuard, Pausable { return activeVenues; } - function simulateSwap(SwapRequest memory request) - external - view - returns (ExecutionQuote memory bestQuote, ExecutionQuote[] memory allQuotes) - { - bestQuote = _getBestExecutionVenue(request); + function simulateSwap( + PoolKey calldata key, + SwapParams calldata params, + CrossChainSwapData memory swapData + ) external view returns (ExecutionQuote memory bestQuote, ExecutionQuote[] memory allQuotes) { + bestQuote = _getBestExecutionVenue(key, params, swapData); allQuotes = new ExecutionQuote[](venueCount); for (uint256 i = 0; i < venueCount; i++) { if (venues[i].isActive) { - allQuotes[i] = _getVenueQuote(request, venues[i], i); + allQuotes[i] = _getVenueQuote(key, params, venues[i], i, swapData); } } diff --git a/test/CrossChainHook.t.sol b/test/CrossChainHook.t.sol new file mode 100644 index 0000000..fb455c3 --- /dev/null +++ b/test/CrossChainHook.t.sol @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {Test} from "forge-std/Test.sol"; +import {CrossChainSwapHook} from "../src/CrossChainSwapHook.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {SwapParams} from "@uniswap/v4-core/src/types/PoolOperation.sol"; + +contract CrossChainHookTest is Test { + CrossChainSwapHook public hook; + IPoolManager public poolManager; + + address public constant PYTH_ORACLE = address(0x123); + address public constant BRIDGE_PROTOCOL = address(0x456); + address public constant FEE_RECIPIENT = address(0x789); + + address public constant TOKEN0 = address(0x1000); + address public constant TOKEN1 = address(0x2000); + + function setUp() public { + // Deploy mock pool manager + poolManager = IPoolManager(address(0x999)); + + // Deploy the hook + hook = new CrossChainSwapHook( + poolManager, + PYTH_ORACLE, + BRIDGE_PROTOCOL, + FEE_RECIPIENT + ); + } + + function testHookPermissions() public { + Hooks.Permissions memory permissions = hook.getHookPermissions(); + + assertTrue(permissions.beforeSwap); + assertTrue(permissions.afterSwap); + assertFalse(permissions.beforeInitialize); + assertFalse(permissions.afterInitialize); + assertFalse(permissions.beforeAddLiquidity); + assertFalse(permissions.afterAddLiquidity); + assertFalse(permissions.beforeRemoveLiquidity); + assertFalse(permissions.afterRemoveLiquidity); + assertFalse(permissions.beforeDonate); + assertFalse(permissions.afterDonate); + } + + function testAddVenue() public { + uint256 chainId = 137; + address venueAddress = address(0x3000); + string memory name = "Polygon Uniswap"; + uint256 gasEstimate = 200000; + + hook.addVenue(chainId, venueAddress, name, gasEstimate); + + CrossChainSwapHook.SwapVenue memory venue = hook.getVenueInfo(1); + + assertEq(venue.chainId, chainId); + assertEq(venue.venueAddress, venueAddress); + assertEq(venue.name, name); + assertTrue(venue.isActive); + assertEq(venue.baseGasEstimate, gasEstimate); + } + + function testConfigurePriceData() public { + address token = TOKEN0; + bytes32 priceId = keccak256("ETH_PRICE_ID"); + uint256 maxStaleness = 600; + + hook.configurePriceData(token, priceId, maxStaleness); + + CrossChainSwapHook.PriceData memory priceData = hook.getTokenPriceData(token); + + assertEq(priceData.priceId, priceId); + assertEq(priceData.maxStaleness, maxStaleness); + assertTrue(priceData.isActive); + } + + function testUpdateSwapParameters() public { + uint256 maxSlippageBps = 500; + uint256 bridgeSlippageBps = 200; + uint256 minBridgeAmount = 200e18; + uint256 crossChainThresholdBps = 300; + + hook.updateSwapParameters(maxSlippageBps, bridgeSlippageBps, minBridgeAmount, crossChainThresholdBps); + + assertEq(hook.maxSlippageBps(), maxSlippageBps); + assertEq(hook.bridgeSlippageBps(), bridgeSlippageBps); + assertEq(hook.minBridgeAmount(), minBridgeAmount); + assertEq(hook.crossChainThresholdBps(), crossChainThresholdBps); + } + + function testPauseUnpause() public { + assertFalse(hook.paused()); + + hook.pause(); + assertTrue(hook.paused()); + + hook.unpause(); + assertFalse(hook.paused()); + } + + function testOnlyOwnerFunctions() public { + address nonOwner = address(0x9999); + + vm.startPrank(nonOwner); + + vm.expectRevert(); + hook.addVenue(137, address(0x3000), "Test", 200000); + + vm.expectRevert(); + hook.configurePriceData(TOKEN0, keccak256("TEST"), 600); + + vm.expectRevert(); + hook.updateSwapParameters(500, 200, 200e18, 300); + + vm.expectRevert(); + hook.pause(); + + vm.stopPrank(); + } + + function testGetAllActiveVenues() public { + // Add multiple venues + hook.addVenue(137, address(0x3000), "Polygon", 200000); + hook.addVenue(42161, address(0x4000), "Arbitrum", 150000); + + // Deactivate one venue + hook.updateVenueStatus(1, false); + + // Get active venues + CrossChainSwapHook.SwapVenue[] memory activeVenues = hook.getAllActiveVenues(); + + // Should have 2 active venues (index 0 is local, index 2 is Arbitrum) + assertEq(activeVenues.length, 2); + assertEq(activeVenues[0].chainId, block.chainid); // Local venue + assertEq(activeVenues[1].chainId, 42161); // Arbitrum venue + } + + function testIsChainSupported() public { + assertTrue(hook.isChainSupported(block.chainid)); // Local chain should be supported + + hook.addVenue(137, address(0x3000), "Polygon", 200000); + assertTrue(hook.isChainSupported(137)); // Polygon should be supported after adding venue + + assertFalse(hook.isChainSupported(999)); // Random chain should not be supported + } +} diff --git a/test/CrossChainHookSimple.t.sol b/test/CrossChainHookSimple.t.sol new file mode 100644 index 0000000..d8ffc87 --- /dev/null +++ b/test/CrossChainHookSimple.t.sol @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {Test} from "forge-std/Test.sol"; +import {CrossChainSwapHook} from "../src/CrossChainSwapHook.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; + +contract CrossChainHookSimpleTest is Test { + CrossChainSwapHook public hook; + IPoolManager public poolManager; + + address public constant PYTH_ORACLE = address(0x123); + address public constant BRIDGE_PROTOCOL = address(0x456); + address public constant FEE_RECIPIENT = address(0x789); + + address public constant TOKEN0 = address(0x1000); + address public constant TOKEN1 = address(0x2000); + + function setUp() public { + // Deploy mock pool manager + poolManager = IPoolManager(address(0x999)); + + // Deploy the hook without validation (for testing only) + vm.etch(address(0x1234567890123456789012345678901234567890), ""); + vm.startPrank(address(0x1234567890123456789012345678901234567890)); + + // Create a mock hook that doesn't validate address + hook = new CrossChainSwapHook( + poolManager, + PYTH_ORACLE, + BRIDGE_PROTOCOL, + FEE_RECIPIENT + ); + + vm.stopPrank(); + } + + function testHookPermissions() public view { + Hooks.Permissions memory permissions = hook.getHookPermissions(); + + assertTrue(permissions.beforeSwap); + assertTrue(permissions.afterSwap); + assertFalse(permissions.beforeInitialize); + assertFalse(permissions.afterInitialize); + assertFalse(permissions.beforeAddLiquidity); + assertFalse(permissions.afterAddLiquidity); + assertFalse(permissions.beforeRemoveLiquidity); + assertFalse(permissions.afterRemoveLiquidity); + assertFalse(permissions.beforeDonate); + assertFalse(permissions.afterDonate); + } + + function testAddVenue() public { + uint256 chainId = 137; + address venueAddress = address(0x3000); + string memory name = "Polygon Uniswap"; + uint256 gasEstimate = 200000; + + hook.addVenue(chainId, venueAddress, name, gasEstimate); + + CrossChainSwapHook.SwapVenue memory venue = hook.getVenueInfo(1); + + assertEq(venue.chainId, chainId); + assertEq(venue.venueAddress, venueAddress); + assertEq(venue.name, name); + assertTrue(venue.isActive); + assertEq(venue.baseGasEstimate, gasEstimate); + } + + function testConfigurePriceData() public { + address token = TOKEN0; + bytes32 priceId = keccak256("ETH_PRICE_ID"); + uint256 maxStaleness = 600; + + hook.configurePriceData(token, priceId, maxStaleness); + + CrossChainSwapHook.PriceData memory priceData = hook.getTokenPriceData(token); + + assertEq(priceData.priceId, priceId); + assertEq(priceData.maxStaleness, maxStaleness); + assertTrue(priceData.isActive); + } + + function testUpdateSwapParameters() public { + uint256 maxSlippageBps = 500; + uint256 bridgeSlippageBps = 200; + uint256 minBridgeAmount = 200e18; + uint256 crossChainThresholdBps = 300; + + hook.updateSwapParameters(maxSlippageBps, bridgeSlippageBps, minBridgeAmount, crossChainThresholdBps); + + assertEq(hook.maxSlippageBps(), maxSlippageBps); + assertEq(hook.bridgeSlippageBps(), bridgeSlippageBps); + assertEq(hook.minBridgeAmount(), minBridgeAmount); + assertEq(hook.crossChainThresholdBps(), crossChainThresholdBps); + } + + function testPauseUnpause() public { + assertFalse(hook.paused()); + + hook.pause(); + assertTrue(hook.paused()); + + hook.unpause(); + assertFalse(hook.paused()); + } + + function testOnlyOwnerFunctions() public { + address nonOwner = address(0x9999); + + vm.startPrank(nonOwner); + + vm.expectRevert(); + hook.addVenue(137, address(0x3000), "Test", 200000); + + vm.expectRevert(); + hook.configurePriceData(TOKEN0, keccak256("TEST"), 600); + + vm.expectRevert(); + hook.updateSwapParameters(500, 200, 200e18, 300); + + vm.expectRevert(); + hook.pause(); + + vm.stopPrank(); + } + + function testGetAllActiveVenues() public { + // Add multiple venues + hook.addVenue(137, address(0x3000), "Polygon", 200000); + hook.addVenue(42161, address(0x4000), "Arbitrum", 150000); + + // Deactivate one venue + hook.updateVenueStatus(1, false); + + // Get active venues + CrossChainSwapHook.SwapVenue[] memory activeVenues = hook.getAllActiveVenues(); + + // Should have 2 active venues (index 0 is local, index 2 is Arbitrum) + assertEq(activeVenues.length, 2); + assertEq(activeVenues[0].chainId, block.chainid); // Local venue + assertEq(activeVenues[1].chainId, 42161); // Arbitrum venue + } + + function testIsChainSupported() public view { + assertTrue(hook.isChainSupported(block.chainid)); // Local chain should be supported + + // Note: We can't test other chains without adding venues first + assertFalse(hook.isChainSupported(999)); // Random chain should not be supported + } +}