From 7caff7bc06023086afc8c0d583d34ec5b213eb5b Mon Sep 17 00:00:00 2001 From: Omri Dan <61094771+omridan159@users.noreply.github.com> Date: Tue, 4 Jun 2024 20:54:03 +0300 Subject: [PATCH] fix: display the DApp URL in connect screen for MetaMask IOS-SDK (#9755) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## **Description** - Display the DApp URL in connect screen for MetaMask IOS-SDK. - Fixes the dapp icon on the connection screen (https://github.com/MetaMask/metamask-mobile/issues/9834) ## **Related issues** Fixes: ## **Manual testing steps** 1. Go to this page... 2. 3. ## **Screenshots/Recordings** ### **Before** ### **After** https://github.com/MetaMask/metamask-mobile/assets/61094771/4da53e62-d80f-4b09-957b-a19f2775e824 ## **Pre-merge author checklist** - [x] I’ve followed [MetaMask Coding Standards](https://github.com/MetaMask/metamask-mobile/blob/main/.github/guidelines/CODING_GUIDELINES.md). - [x] I've completed the PR template to the best of my ability - [x] I’ve included tests if applicable - [x] I’ve documented my code using [JSDoc](https://jsdoc.app/) format if applicable - [x] I’ve applied the right labels on the PR (see [labeling guidelines](https://github.com/MetaMask/metamask-mobile/blob/main/.github/guidelines/LABELING_GUIDELINES.md)). Not required for external contributors. ## **Pre-merge reviewer checklist** - [ ] I've manually tested the PR (e.g. pull and build branch, run the app, test code being changed). - [ ] I confirm that this PR addresses all acceptance criteria described in the ticket it closes and includes the necessary testing evidence such as recordings and or screenshots. --------- Co-authored-by: Christopher Ferreira <104831203+christopherferreira9@users.noreply.github.com> --- .../Views/AccountConnect/AccountConnect.tsx | 42 ++- .../hooks/useMetrics/useMetrics.types.ts | 7 + .../ParseManager/extractURLParams.ts | 1 + .../handleMetaMaskDeeplink.test.ts | 133 +++++++ .../ParseManager/handleMetaMaskDeeplink.ts | 3 +- .../SDKConnect/AndroidSDK/AndroidService.ts | 12 +- .../DeeplinkProtocolService.test.ts | 355 ++++++++++++++++++ .../DeeplinkProtocolService.ts | 158 +++++--- 8 files changed, 634 insertions(+), 77 deletions(-) create mode 100644 app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.test.ts diff --git a/app/components/Views/AccountConnect/AccountConnect.tsx b/app/components/Views/AccountConnect/AccountConnect.tsx index f4b6499179c..46abc06b77c 100644 --- a/app/components/Views/AccountConnect/AccountConnect.tsx +++ b/app/components/Views/AccountConnect/AccountConnect.tsx @@ -65,6 +65,8 @@ import { import AccountConnectMultiSelector from './AccountConnectMultiSelector'; import AccountConnectSingle from './AccountConnectSingle'; import AccountConnectSingleSelector from './AccountConnectSingleSelector'; +import { SourceType } from '../../hooks/useMetrics/useMetrics.types'; + const createStyles = () => StyleSheet.create({ fullScreenModal: { @@ -107,24 +109,28 @@ const AccountConnect = (props: AccountConnectProps) => { : AvatarAccountType.JazzIcon, ); - // on inappBrowser: hostname - // on walletConnect: hostname - // on sdk or walletconnect + // origin is set to the last active tab url in the browser which can conflict with sdk + const inappBrowserOrigin: string = useSelector(getActiveTabUrl, isEqual); + const accountsLength = useSelector(selectAccountsLength); + + // TODO: pending transaction controller update, we need to have a parameter that can be extracted from the metadata to know the correct source (inappbrowser, walletconnect, sdk) + // on inappBrowser: hostname from inappBrowserOrigin + // on walletConnect: hostname from hostInfo + // on sdk: channelId const { origin: channelIdOrHostname } = hostInfo.metadata as { id: string; origin: string; }; - const origin: string = useSelector(getActiveTabUrl, isEqual); - const accountsLength = useSelector(selectAccountsLength); - const sdkConnection = SDKConnect.getInstance().getConnection({ channelId: channelIdOrHostname, }); - const hostname = - origin ?? channelIdOrHostname.indexOf('.') !== -1 + + const hostname = channelIdOrHostname + ? channelIdOrHostname.indexOf('.') !== -1 ? channelIdOrHostname - : sdkConnection?.originatorInfo?.url ?? ''; + : sdkConnection?.originatorInfo?.url ?? '' + : inappBrowserOrigin; const urlWithProtocol = prefixUrlWithProtocol(hostname); @@ -160,7 +166,7 @@ const AccountConnect = (props: AccountConnectProps) => { } }, [isAllowedUrl, dappUrl, channelIdOrHostname]); - const faviconSource = useFavicon(origin); + const faviconSource = useFavicon(inappBrowserOrigin); const actualIcon = useMemo( () => (dappIconUrl ? { uri: dappIconUrl } : faviconSource), @@ -179,13 +185,15 @@ const AccountConnect = (props: AccountConnectProps) => { // walletconnect channelId format: app.name.org // sdk channelId format: uuid // inappbrowser channelId format: app.name.org but origin is set - if (sdkConnection) { - return 'sdk'; - } else if (origin) { - return 'in-app browser'; + if (channelIdOrHostname) { + if (sdkConnection) { + return SourceType.SDK; + } + return SourceType.WALLET_CONNECT; } - return 'walletconnect'; - }, [sdkConnection, origin]); + + return SourceType.IN_APP_BROWSER; + }, [sdkConnection, channelIdOrHostname]); // Refreshes selected addresses based on the addition and removal of accounts. useEffect(() => { @@ -216,7 +224,7 @@ const AccountConnect = (props: AccountConnectProps) => { trackEvent(MetaMetricsEvents.CONNECT_REQUEST_CANCELLED, { number_of_accounts: accountsLength, - source: 'permission system', + source: SourceType.PERMISSION_SYSTEM, }); }, [ diff --git a/app/components/hooks/useMetrics/useMetrics.types.ts b/app/components/hooks/useMetrics/useMetrics.types.ts index 496ff6fda6f..24cdd2a753d 100644 --- a/app/components/hooks/useMetrics/useMetrics.types.ts +++ b/app/components/hooks/useMetrics/useMetrics.types.ts @@ -6,6 +6,13 @@ import { IMetaMetricsEvent, } from '../../../core/Analytics/MetaMetrics.types'; +export enum SourceType { + SDK = 'sdk', + WALLET_CONNECT = 'walletconnect', + IN_APP_BROWSER = 'in-app browser', + PERMISSION_SYSTEM = 'permission system', +} + export interface IUseMetricsHook { isEnabled(): boolean; enable(enable?: boolean): Promise; diff --git a/app/core/DeeplinkManager/ParseManager/extractURLParams.ts b/app/core/DeeplinkManager/ParseManager/extractURLParams.ts index 4018296e2ec..db816605f43 100644 --- a/app/core/DeeplinkManager/ParseManager/extractURLParams.ts +++ b/app/core/DeeplinkManager/ParseManager/extractURLParams.ts @@ -17,6 +17,7 @@ export interface DeeplinkUrlParams { message?: string; originatorInfo?: string; request?: string; + account?: string; // This is the format => "address@chainId" } function extractURLParams(url: string) { diff --git a/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.test.ts b/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.test.ts index ac5bbe3e3bf..9ffede5aee6 100644 --- a/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.test.ts +++ b/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.test.ts @@ -121,6 +121,139 @@ describe('handleMetaMaskProtocol', () => { }); }); + describe('when params.comm is "deeplinking"', () => { + beforeEach(() => { + url = `${PREFIXES.METAMASK}${ACTIONS.CONNECT}`; + params.comm = 'deeplinking'; + params.channelId = 'test-channel-id'; + params.pubkey = 'test-pubkey'; + params.originatorInfo = 'test-originator-info'; + params.request = 'test-request'; + }); + + it('should throw an error if params.scheme is not defined', () => { + params.scheme = undefined; + + expect(() => { + handleMetaMaskDeeplink({ + instance, + handled, + params, + url, + origin, + wcURL, + }); + }).toThrow('DeepLinkManager failed to connect - Invalid scheme'); + }); + + it('should call handleConnection if params.scheme is defined', () => { + const mockHandleConnection = jest.fn(); + mockSDKConnectGetInstance.mockImplementation(() => ({ + state: { + deeplinkingService: { + handleConnection: mockHandleConnection, + }, + }, + })); + + params.scheme = 'test-scheme'; + + handleMetaMaskDeeplink({ + instance, + handled, + params, + url, + origin, + wcURL, + }); + + expect(mockHandleConnection).toHaveBeenCalledWith({ + channelId: params.channelId, + url, + scheme: params.scheme, + dappPublicKey: params.pubkey, + originatorInfo: params.originatorInfo, + request: params.request, + }); + }); + }); + + describe('when url starts with ${PREFIXES.METAMASK}${ACTIONS.MMSDK}', () => { + beforeEach(() => { + url = `${PREFIXES.METAMASK}${ACTIONS.MMSDK}`; + params.channelId = 'test-channel-id'; + params.pubkey = 'test-pubkey'; + params.account = 'test-account'; + }); + + it('should throw an error if params.message is not defined', () => { + params.message = undefined; + + expect(() => { + handleMetaMaskDeeplink({ + instance, + handled, + params, + url, + origin, + wcURL, + }); + }).toThrow( + 'DeepLinkManager: deeplinkingService failed to handleMessage - Invalid message', + ); + }); + + it('should throw an error if params.scheme is not defined', () => { + params.message = 'test-message'; + params.scheme = undefined; + + expect(() => { + handleMetaMaskDeeplink({ + instance, + handled, + params, + url, + origin, + wcURL, + }); + }).toThrow( + 'DeepLinkManager: deeplinkingService failed to handleMessage - Invalid scheme', + ); + }); + + it('should call handleMessage if params.message and params.scheme are defined', () => { + const mockHandleMessage = jest.fn(); + mockSDKConnectGetInstance.mockImplementation(() => ({ + state: { + deeplinkingService: { + handleMessage: mockHandleMessage, + }, + }, + })); + + params.message = 'test-message'; + params.scheme = 'test-scheme'; + + handleMetaMaskDeeplink({ + instance, + handled, + params, + url, + origin, + wcURL, + }); + + expect(mockHandleMessage).toHaveBeenCalledWith({ + channelId: params.channelId, + url, + message: params.message, + dappPublicKey: params.pubkey, + scheme: params.scheme, + account: params.account ?? '@', + }); + }); + }); + describe('when url starts with ${PREFIXES.METAMASK}${ACTIONS.CONNECT}', () => { beforeEach(() => { url = `${PREFIXES.METAMASK}${ACTIONS.CONNECT}`; diff --git a/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.ts b/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.ts index e6dc7c4e1d7..db98211622b 100644 --- a/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.ts +++ b/app/core/DeeplinkManager/ParseManager/handleMetaMaskDeeplink.ts @@ -89,14 +89,13 @@ export function handleMetaMaskDeeplink({ ); } - DevLogger.log('DeepLinkManager:: ===> params from deeplink', params); - SDKConnect.getInstance().state.deeplinkingService?.handleMessage({ channelId: params.channelId, url, message: params.message, dappPublicKey: params.pubkey, scheme: params.scheme, + account: params.account ?? '@', }); } else if ( url.startsWith(`${PREFIXES.METAMASK}${ACTIONS.WC}`) || diff --git a/app/core/SDKConnect/AndroidSDK/AndroidService.ts b/app/core/SDKConnect/AndroidSDK/AndroidService.ts index 61bf3983802..7903d4fb2d1 100644 --- a/app/core/SDKConnect/AndroidSDK/AndroidService.ts +++ b/app/core/SDKConnect/AndroidSDK/AndroidService.ts @@ -122,7 +122,7 @@ export default class AndroidService extends EventEmitter2 { } private setupOnClientsConnectedListener() { - this.eventHandler.onClientsConnected((sClientInfo: string) => { + this.eventHandler.onClientsConnected(async (sClientInfo: string) => { const clientInfo: DappClient = JSON.parse(sClientInfo); DevLogger.log(`AndroidService::clients_connected`, clientInfo); @@ -155,6 +155,15 @@ export default class AndroidService extends EventEmitter2 { return; } + await SDKConnect.getInstance().addDappConnection({ + id: clientInfo.clientId, + lastAuthorized: Date.now(), + origin: AppConstants.MM_SDK.ANDROID_SDK, + originatorInfo: clientInfo.originatorInfo, + otherPublicKey: '', + validUntil: Date.now() + DEFAULT_SESSION_TIMEOUT_MS, + }); + const handleEventAsync = async () => { const keyringController = ( Engine.context as { KeyringController: KeyringController } @@ -361,6 +370,7 @@ export default class AndroidService extends EventEmitter2 { const chainId = networkController.state.providerConfig.chainId; this.currentClientId = sessionId; + // Handle custom rpc method const processedRpc = await handleCustomRpcCalls({ batchRPCManager: this.batchRPCManager, diff --git a/app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.test.ts b/app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.test.ts new file mode 100644 index 00000000000..ff32719c5e5 --- /dev/null +++ b/app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.test.ts @@ -0,0 +1,355 @@ +/* eslint-disable @typescript-eslint/ban-ts-comment */ +import { Linking } from 'react-native'; +import Engine from '../../../core/Engine'; +import Logger from '../../../util/Logger'; +import BackgroundBridge from '../../BackgroundBridge/BackgroundBridge'; +import SDKConnect from '../SDKConnect'; +import handleBatchRpcResponse from '../handlers/handleBatchRpcResponse'; +import handleCustomRpcCalls from '../handlers/handleCustomRpcCalls'; +import DevLogger from '../utils/DevLogger'; +import DeeplinkProtocolService from './DeeplinkProtocolService'; + +jest.mock('../SDKConnect'); +jest.mock('../../../core/Engine'); +jest.mock('react-native'); +jest.mock('../../BackgroundBridge/BackgroundBridge'); +jest.mock('../utils/DevLogger'); +jest.mock('../../../util/Logger'); +jest.mock('../handlers/handleCustomRpcCalls'); +jest.mock('../handlers/handleBatchRpcResponse'); + +describe('DeeplinkProtocolService', () => { + let service: DeeplinkProtocolService; + + beforeEach(() => { + jest.clearAllMocks(); + (SDKConnect.getInstance as jest.Mock).mockReturnValue({ + loadDappConnections: jest.fn().mockResolvedValue({ + connection1: { + id: 'connection1', + originatorInfo: { url: 'test.com', title: 'Test' }, + validUntil: Date.now(), + scheme: 'scheme1', + }, + }), + addDappConnection: jest.fn().mockResolvedValue(null), + }); + + (Engine.context as any) = { + PermissionController: { + requestPermissions: jest.fn().mockResolvedValue(null), + getPermissions: jest + .fn() + .mockReturnValue({ eth_accounts: { caveats: [{ value: [] }] } }), + }, + KeyringController: { unlock: jest.fn() }, + NetworkController: { state: { providerConfig: { chainId: '0x1' } } }, + PreferencesController: { state: { selectedAddress: '0xAddress' } }, + }; + + (Linking.openURL as jest.Mock).mockResolvedValue(null); + service = new DeeplinkProtocolService(); + }); + + describe('init', () => { + it('should initialize and load connections', async () => { + const spy = jest.spyOn(SDKConnect.getInstance(), 'loadDappConnections'); + await service.init(); + expect(spy).toHaveBeenCalled(); + expect(service.isInitialized).toBe(true); + }); + + it('should handle initialization error', async () => { + ( + SDKConnect.getInstance().loadDappConnections as jest.Mock + ).mockRejectedValue(new Error('Failed to load connections')); + await service.init().catch(() => { + expect(service.isInitialized).toBe(false); + expect(Logger.log).toHaveBeenCalledWith( + expect.any(Error), + 'DeeplinkProtocolService:: error initializing', + ); + }); + }); + + it('should initialize with raw connections', async () => { + await service.init(); + expect(service.connections.connection1).toBeDefined(); + }); + }); + + describe('setupBridge', () => { + it('should set up a bridge for the client', () => { + const clientInfo = { + clientId: 'client1', + originatorInfo: { + url: 'test.com', + title: 'Test', + platform: 'test', + dappId: 'dappId', + }, + connected: false, + validUntil: Date.now(), + scheme: 'test', + }; + service.setupBridge(clientInfo); + expect(service.bridgeByClientId[clientInfo.clientId]).toBeInstanceOf( + BackgroundBridge, + ); + }); + + it('should return early if bridge already exists', () => { + const clientInfo = { + clientId: 'client1', + originatorInfo: { + url: 'test.com', + title: 'Test', + platform: 'test', + dappId: 'dappId', + }, + connected: false, + validUntil: Date.now(), + scheme: 'test', + }; + service.bridgeByClientId.client1 = {} as BackgroundBridge; + const setupBridgeSpy = jest.spyOn(service as any, 'setupBridge'); + service.setupBridge(clientInfo); + expect(setupBridgeSpy).toHaveReturned(); + }); + }); + describe('sendMessage', () => { + it('should handle sending messages correctly', async () => { + service.rpcQueueManager.getId = jest.fn().mockReturnValue('rpcMethod'); + service.batchRPCManager.getById = jest.fn().mockReturnValue(null); + service.rpcQueueManager.isEmpty = jest.fn().mockReturnValue(true); + service.rpcQueueManager.remove = jest.fn(); // Mock the remove method + + await service.sendMessage({ data: { id: '1' } }, true); + expect(service.rpcQueueManager.remove).toHaveBeenCalledWith('1'); + }); + + it('should handle batch RPC responses', async () => { + const mockChainRPCs = [{ id: '1' }]; + const mockMessage = { data: { id: '1', error: null } }; + service.batchRPCManager.getById = jest + .fn() + .mockReturnValue(mockChainRPCs); + (handleBatchRpcResponse as jest.Mock).mockResolvedValue(true); + + service.currentClientId = 'client1'; + service.bridgeByClientId.client1 = new BackgroundBridge({ + webview: null, + channelId: 'client1', + isMMSDK: true, + url: 'test-url', + isRemoteConn: true, + sendMessage: jest.fn(), + } as any); + + await service.sendMessage(mockMessage, true); + expect(handleBatchRpcResponse).toHaveBeenCalledWith( + expect.objectContaining({ + chainRpcs: mockChainRPCs, + msg: mockMessage, + backgroundBridge: expect.any(BackgroundBridge), + batchRPCManager: expect.anything(), + sendMessage: expect.any(Function), + }), + ); + }); + + it('should handle error in message data', async () => { + const mockMessage = { data: { id: '1', error: new Error('Test error') } }; + const openDeeplinkSpy = jest.spyOn(service, 'openDeeplink'); + + service.currentClientId = 'client1'; + service.bridgeByClientId.client1 = new BackgroundBridge({ + webview: null, + channelId: 'client1', + isMMSDK: true, + url: 'test-url', + isRemoteConn: true, + sendMessage: jest.fn(), + } as any); + + await service.sendMessage(mockMessage, true); + expect(openDeeplinkSpy).toHaveBeenCalledWith({ + message: mockMessage, + clientId: 'client1', + }); + }); + + it('should skip goBack if no rpc method and forceRedirect is not true', async () => { + const mockMessage = { data: { id: '1' } }; + const devLoggerSpy = jest.spyOn(DevLogger, 'log'); + + service.rpcQueueManager.getId = jest.fn().mockReturnValue(undefined); + service.rpcQueueManager.isEmpty = jest.fn().mockReturnValue(true); + + await service.sendMessage(mockMessage); + expect(devLoggerSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'no rpc method --- rpcMethod=undefined forceRedirect=undefined --- skip goBack()', + ), + ); + }); + + it('should handle non-final batch RPC response and error in message data', async () => { + const mockChainRPCs = [{ id: '1' }]; + const mockMessage = { data: { id: '1', error: new Error('Test error') } }; + const devLoggerSpy = jest.spyOn(DevLogger, 'log'); + const openDeeplinkSpy = jest.spyOn(service, 'openDeeplink'); + service.batchRPCManager.getById = jest + .fn() + .mockReturnValue(mockChainRPCs); + (handleBatchRpcResponse as jest.Mock).mockResolvedValue(false); + + service.currentClientId = 'client1'; + service.bridgeByClientId.client1 = new BackgroundBridge({ + webview: null, + channelId: 'client1', + isMMSDK: true, + url: 'test-url', + isRemoteConn: true, + sendMessage: jest.fn(), + } as any); + + service.rpcQueueManager.remove = jest.fn(); + + await service.sendMessage(mockMessage, true); + expect(devLoggerSpy).toHaveBeenCalledWith( + expect.stringContaining('NOT last rpc --- skip goBack()'), + mockChainRPCs, + ); + expect(service.rpcQueueManager.remove).toHaveBeenCalledWith('1'); + expect(openDeeplinkSpy).toHaveBeenCalledWith({ + message: mockMessage, + clientId: 'client1', + }); + }); + + it('should update connection state and skip bridge setup if session exists', async () => { + const connectionParams = { + dappPublicKey: 'key', + url: 'url', + scheme: 'scheme', + channelId: 'channel1', + originatorInfo: Buffer.from( + JSON.stringify({ + originatorInfo: { + url: 'test.com', + title: 'Test', + platform: 'test', + dappId: 'dappId', + }, + }), + ).toString('base64'), + }; + service.connections.channel1 = { + clientId: 'channel1', + originatorInfo: { + url: 'test.com', + title: 'Test', + platform: 'test', + dappId: 'dappId', + }, + connected: false, + validUntil: Date.now(), + scheme: 'scheme', + }; + + await service.handleConnection(connectionParams); + + expect(service.connections.channel1.connected).toBe(true); + }); + }); + + describe('openDeeplink', () => { + it('should open a deeplink with the provided message', async () => { + const spy = jest.spyOn(Linking, 'openURL'); + await service.openDeeplink({ + message: { test: 'test' }, + clientId: 'client1', + }); + expect(spy).toHaveBeenCalled(); + }); + }); + + describe('checkPermission', () => { + it('should request permissions', async () => { + const spy = jest.spyOn( + Engine.context.PermissionController, + 'requestPermissions', + ); + await service.checkPermission({ + channelId: 'channel1', + originatorInfo: { + url: 'test.com', + title: 'Test', + platform: 'test', + dappId: 'dappId', + }, + }); + expect(spy).toHaveBeenCalled(); + }); + }); + + describe('handleConnection', () => { + it('should handle a new connection', async () => { + const connectionParams = { + dappPublicKey: 'key', + url: 'url', + scheme: 'scheme', + channelId: 'channel1', + originatorInfo: Buffer.from( + JSON.stringify({ + originatorInfo: { url: 'test.com', title: 'Test' }, + }), + ).toString('base64'), + }; + await service.handleConnection(connectionParams); + expect(service.connections.connection1).toBeDefined(); + }); + }); + + describe('processDappRpcRequest', () => { + it('should process a dapp RPC request', async () => { + const params = { + dappPublicKey: 'key', + url: 'url', + scheme: 'scheme', + channelId: 'channel1', + originatorInfo: 'info', + request: JSON.stringify({ id: '1', method: 'test', params: [] }), + }; + service.bridgeByClientId.channel1 = { onMessage: jest.fn() } as any; + await service.processDappRpcRequest(params); + expect(handleCustomRpcCalls).toHaveBeenCalled(); + }); + }); + + describe('handleMessage', () => { + it('should handle an incoming message', () => { + const params = { + dappPublicKey: 'key', + url: 'url', + message: Buffer.from( + JSON.stringify({ id: '1', method: 'test', params: [] }), + ).toString('base64'), + channelId: 'channel1', + scheme: 'scheme', + account: '0xAddress@1', + }; + service.handleMessage(params); + expect(DevLogger.log).toHaveBeenCalled(); + }); + }); + + describe('removeConnection', () => { + it('should remove a connection', () => { + service.connections.channel1 = {} as any; + service.removeConnection('channel1'); + expect(service.connections.channel1).toBeUndefined(); + }); + }); +}); diff --git a/app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.ts b/app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.ts index 18910b55c67..e605c4f1049 100644 --- a/app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.ts +++ b/app/core/SDKConnect/SDKDeeplinkProtocol/DeeplinkProtocolService.ts @@ -10,6 +10,7 @@ import Engine from '../../../core/Engine'; import Logger from '../../../util/Logger'; import BackgroundBridge from '../../BackgroundBridge/BackgroundBridge'; import { DappClient, DappConnections } from '../AndroidSDK/dapp-sdk-types'; +import getDefaultBridgeParams from '../AndroidSDK/getDefaultBridgeParams'; import BatchRPCManager from '../BatchRPCManager'; import RPCQueueManager from '../RPCQueueManager'; import SDKConnect from '../SDKConnect'; @@ -22,20 +23,19 @@ import handleBatchRpcResponse from '../handlers/handleBatchRpcResponse'; import handleCustomRpcCalls from '../handlers/handleCustomRpcCalls'; import DevLogger from '../utils/DevLogger'; import { wait, waitForKeychainUnlocked } from '../utils/wait.util'; -import getDefaultBridgeParams from '../AndroidSDK/getDefaultBridgeParams'; export default class DeeplinkProtocolService { - private connections: DappConnections = {}; - private bridgeByClientId: { [clientId: string]: BackgroundBridge } = {}; - private rpcQueueManager = new RPCQueueManager(); - private batchRPCManager: BatchRPCManager = new BatchRPCManager('deeplink'); + public connections: DappConnections = {}; + public bridgeByClientId: { [clientId: string]: BackgroundBridge } = {}; + public rpcQueueManager = new RPCQueueManager(); + public batchRPCManager: BatchRPCManager = new BatchRPCManager('deeplink'); // To keep track in order to get the associated bridge to handle batch rpc calls - private currentClientId?: string; - private dappPublicKeyByClientId: { + public currentClientId?: string; + public dappPublicKeyByClientId: { [clientId: string]: string; } = {}; - private isInitialized = false; + public isInitialized = false; public constructor() { if (!this.isInitialized) { @@ -51,7 +51,7 @@ export default class DeeplinkProtocolService { } } - private async init() { + public async init() { if (this.isInitialized) { return; } @@ -60,9 +60,6 @@ export default class DeeplinkProtocolService { if (rawConnections) { Object.values(rawConnections).forEach((connection) => { - DevLogger.log( - `DeeplinkProtocolService::init recover client: ${connection.id}`, - ); const clientInfo = { connected: false, clientId: connection.id, @@ -75,14 +72,10 @@ export default class DeeplinkProtocolService { this.setupBridge(clientInfo); }); - } else { - DevLogger.log( - `DeeplinkProtocolService::init no previous connections found`, - ); } } - private setupBridge(clientInfo: DappClient) { + public setupBridge(clientInfo: DappClient) { DevLogger.log( `DeeplinkProtocolService::setupBridge for id=${ clientInfo.clientId @@ -145,11 +138,6 @@ export default class DeeplinkProtocolService { sendMessage: ({ msg }) => this.sendMessage(msg), }); - DevLogger.log( - `DeeplinkProtocolService::sendMessage isLastRpc=${isLastRpcOrError}`, - chainRPCs, - ); - const hasError = !!message?.data?.error; if (!isLastRpcOrError) { @@ -172,10 +160,6 @@ export default class DeeplinkProtocolService { // Always set the method to metamask_batch otherwise it may not have been set correctly because of the batch rpc flow. rpcMethod = RPC_METHODS.METAMASK_BATCH; - - DevLogger.log( - `DeeplinkProtocolService::sendMessage chainRPCs=${chainRPCs} COMPLETED!`, - ); } this.rpcQueueManager.remove(id); @@ -229,7 +213,7 @@ export default class DeeplinkProtocolService { } } - private async openDeeplink({ + public async openDeeplink({ message, clientId, scheme, @@ -264,7 +248,7 @@ export default class DeeplinkProtocolService { } } - private async checkPermission({ + public async checkPermission({ channelId, }: { originatorInfo: OriginatorInfo; @@ -299,8 +283,6 @@ export default class DeeplinkProtocolService { this.dappPublicKeyByClientId[params.channelId] = params.dappPublicKey; - Logger.log('DeeplinkProtocolService::handleConnection params', params); - const decodedOriginatorInfo = Buffer.from( params.originatorInfo, 'base64', @@ -308,17 +290,8 @@ export default class DeeplinkProtocolService { const originatorInfoJson = JSON.parse(decodedOriginatorInfo); - DevLogger.log( - `DeeplinkProtocolService::handleConnection originatorInfoJson`, - originatorInfoJson, - ); - const originatorInfo = originatorInfoJson.originatorInfo; - Logger.log( - `DeeplinkProtocolService::originatorInfo: ${originatorInfo.url} ${originatorInfo.title}`, - ); - const clientInfo: DappClient = { clientId: params.channelId, originatorInfo, @@ -329,17 +302,11 @@ export default class DeeplinkProtocolService { this.currentClientId = params.channelId; - DevLogger.log(`DeeplinkProtocolService::clients_connected`, clientInfo); - const isSessionExists = this.connections?.[clientInfo.clientId]; if (isSessionExists) { // Skip existing client -- bridge has been setup - Logger.log( - `DeeplinkProtocolService::clients_connected - existing client, sending ready`, - ); - // Update connected state this.connections[clientInfo.clientId] = { ...this.connections[clientInfo.clientId], @@ -368,6 +335,16 @@ export default class DeeplinkProtocolService { return; } + await SDKConnect.getInstance().addDappConnection({ + id: clientInfo.clientId, + lastAuthorized: Date.now(), + origin: AppConstants.MM_SDK.IOS_SDK, + originatorInfo: clientInfo.originatorInfo, + otherPublicKey: this.dappPublicKeyByClientId[clientInfo.clientId], + validUntil: Date.now() + DEFAULT_SESSION_TIMEOUT_MS, + scheme: clientInfo.scheme, + }); + const handleEventAsync = async () => { const keyringController = ( Engine.context as { KeyringController: KeyringController } @@ -412,8 +389,6 @@ export default class DeeplinkProtocolService { }); } - DevLogger.log(`DeeplinkProtocolService::sendMessage 2`); - if (params.request) { await this.processDappRpcRequest(params); @@ -461,13 +436,6 @@ export default class DeeplinkProtocolService { name: 'metamask-provider', }; - // TODO: Remove this log after testing - DevLogger.log( - `DeeplinkProtocolService::sendMessage handleEventAsync hasError ===> sending deeplink`, - message, - this.currentClientId, - ); - this.openDeeplink({ message, clientId: this.currentClientId ?? '', @@ -486,7 +454,7 @@ export default class DeeplinkProtocolService { }); } - private async processDappRpcRequest(params: { + public async processDappRpcRequest(params: { dappPublicKey: string; url: string; scheme: string; @@ -556,7 +524,13 @@ export default class DeeplinkProtocolService { this.currentClientId ?? '', ); - const connectedAddresses = permissions?.eth_accounts?.caveats?.[0] + const preferencesController = ( + Engine.context as { PreferencesController: PreferencesController } + ).PreferencesController; + + const selectedAddress = preferencesController.state.selectedAddress; + + let connectedAddresses = permissions?.eth_accounts?.caveats?.[0] ?.value as string[]; DevLogger.log( @@ -564,7 +538,29 @@ export default class DeeplinkProtocolService { connectedAddresses, ); - return connectedAddresses ?? []; + if (!Array.isArray(connectedAddresses)) { + return []; + } + + const lowerCaseConnectedAddresses = connectedAddresses.map((address) => + address.toLowerCase(), + ); + + const isPartOfConnectedAddresses = lowerCaseConnectedAddresses.includes( + selectedAddress.toLowerCase(), + ); + + if (isPartOfConnectedAddresses) { + // Create a new array with selectedAddress at the first position + connectedAddresses = [ + selectedAddress, + ...connectedAddresses.filter( + (address) => address.toLowerCase() !== selectedAddress.toLowerCase(), + ), + ]; + } + + return connectedAddresses; } public getSelectedAddress() { @@ -590,7 +586,25 @@ export default class DeeplinkProtocolService { message: string; channelId: string; scheme: string; + account: string; // account@chainid }) { + let walletSelectedAddress = ''; + let walletSelectedChainId = ''; + let dappAccountChainId = ''; + let dappAccountAddress = ''; + + if (!params.account?.includes('@')) { + DevLogger.log( + `DeeplinkProtocolService:: handleMessage invalid params.account format ${params.account}`, + ); + } else { + const account = params.account.split('@'); + walletSelectedAddress = this.getSelectedAddress(); + walletSelectedChainId = this.getChainId(); + dappAccountChainId = account[1]; + dappAccountAddress = account[0]; + } + DevLogger.log( 'DeeplinkProtocolService:: handleMessage params from deeplink', params, @@ -632,6 +646,36 @@ export default class DeeplinkProtocolService { const message = JSON.parse(parsedMessage); // handle message and redirect to corresponding bridge DevLogger.log('DeeplinkProtocolService:: parsed message:-', message); data = message; + + const isAccountChanged = dappAccountAddress !== walletSelectedAddress; + const isChainChanged = dappAccountChainId !== walletSelectedChainId; + + if (isAccountChanged || isChainChanged) { + this.sendMessage( + { + data: { + id: data.id, + accounts: this.getSelectedAccounts(), + chainId: this.getChainId(), + error: { + code: -32602, + message: + 'The selected account or chain has changed. Please try again.', + }, + jsonrpc: '2.0', + }, + name: 'metamask-provider', + }, + true, + ).catch((err) => { + Logger.log( + err, + `DeeplinkProtocolService::onMessageReceived error sending jsonrpc error message to client ${sessionId}`, + ); + }); + + return; + } } catch (error) { Logger.log( error,