diff --git a/src/background/providers/ChainAgnosticProvider.test.ts b/src/background/providers/ChainAgnosticProvider.test.ts new file mode 100644 index 000000000..59fb4715f --- /dev/null +++ b/src/background/providers/ChainAgnosticProvider.test.ts @@ -0,0 +1,220 @@ +import { ethErrors } from 'eth-rpc-errors'; +import AutoPairingPostMessageConnection from '../utils/messaging/AutoPairingPostMessageConnection'; +import { ChainAgnosticProvider } from './ChainAgnosticProvider'; +import onDomReady from './utils/onDomReady'; +import { DAppProviderRequest } from '../connections/dAppConnection/models'; + +jest.mock('../utils/messaging/AutoPairingPostMessageConnection', () => { + const mocks = { + connect: jest.fn().mockResolvedValue(undefined), + on: jest.fn(), + request: jest.fn().mockResolvedValue({}), + }; + return jest.fn().mockReturnValue(mocks); +}); + +export const matchingPayload = (payload) => + expect.objectContaining({ + data: expect.objectContaining(payload), + }); + +jest.mock('./utils/onDomReady'); +jest.mock('../utils/messaging/AutoPairingPostMessageConnection', () => { + const mocks = { + connect: jest.fn().mockResolvedValue(undefined), + on: jest.fn(), + request: jest.fn().mockResolvedValue({}), + }; + return jest.fn().mockReturnValue(mocks); +}); +describe('src/background/providers/ChainAgnosticProvider', () => { + const channelMock = new AutoPairingPostMessageConnection(false); + + describe('initialization', () => { + it('should connect to the backgroundscript', async () => { + new ChainAgnosticProvider(channelMock); + + expect(channelMock.connect).toHaveBeenCalled(); + expect(channelMock.request).not.toHaveBeenCalled(); + }); + it('should wait for message channel to be connected', async () => { + const mockedChannel = new AutoPairingPostMessageConnection(false); + + const provider = new ChainAgnosticProvider(channelMock); + + await new Promise(process.nextTick); + + (onDomReady as jest.Mock).mock.calls[0][0](); + + expect(mockedChannel.connect).toHaveBeenCalled(); + expect(mockedChannel.request).not.toHaveBeenCalled(); + + await provider.request({ + data: { method: 'some-method', params: [{ param1: 1 }] }, + sessionId: '00000000-0000-0000-0000-000000000000', + chainId: '1', + }); + expect(mockedChannel.request).toHaveBeenCalled(); + }); + it('should call the `DOMAIN_METADATA_METHOD` adter domReady', async () => { + new ChainAgnosticProvider(channelMock); + await new Promise(process.nextTick); + expect(channelMock.request).toHaveBeenCalledTimes(0); + (onDomReady as jest.Mock).mock.calls[0][0](); + await new Promise(process.nextTick); + + expect(channelMock.request).toHaveBeenCalledTimes(1); + + expect(channelMock.request).toHaveBeenCalledWith( + // matchingPayload({ + // method: DAppProviderRequest.INIT_DAPP_STATE, + // }) + expect.objectContaining({ + params: expect.objectContaining({ + request: expect.objectContaining({ + method: DAppProviderRequest.DOMAIN_METADATA_METHOD, + }), + }), + }) + ); + }); + }); + + describe('request', () => { + it('should collect pending requests till the dom is ready', async () => { + const provider = new ChainAgnosticProvider(channelMock); + // wait for init to finish + await new Promise(process.nextTick); + + expect(channelMock.request).toHaveBeenCalledTimes(0); + + (channelMock.request as jest.Mock).mockResolvedValue('success'); + const rpcResultCallback = jest.fn(); + provider + .request({ + data: { + method: 'some-method', + params: [{ param1: 1 }], + }, + }) + .then(rpcResultCallback); + await new Promise(process.nextTick); + + expect(channelMock.request).toHaveBeenCalledTimes(0); + + // domReady triggers sending pending requests as well + (onDomReady as jest.Mock).mock.calls[0][0](); + await new Promise(process.nextTick); + + expect(channelMock.request).toHaveBeenCalledTimes(2); + + expect(rpcResultCallback).toHaveBeenCalledWith('success'); + }); + it('should use the rate limits on `eth_requestAccounts` requests', async () => { + const provider = new ChainAgnosticProvider(channelMock); + (channelMock.request as jest.Mock).mockResolvedValue('success'); + + await new Promise(process.nextTick); + + (onDomReady as jest.Mock).mock.calls[0][0](); + + const firstCallCallback = jest.fn(); + const secondCallCallback = jest.fn(); + provider + .request({ + data: { method: 'eth_requestAccounts' }, + } as any) + .then(firstCallCallback) + .catch(firstCallCallback); + provider + .request({ + data: { method: 'eth_requestAccounts' }, + } as any) + .then(secondCallCallback) + .catch(secondCallCallback); + + await new Promise(process.nextTick); + expect(firstCallCallback).toHaveBeenCalledWith('success'); + expect(secondCallCallback).toHaveBeenCalledWith( + ethErrors.rpc.resourceUnavailable( + `Request of type eth_requestAccounts already pending for origin. Please wait.` + ) + ); + }); + it('shoud not use the rate limits on `random_method` requests', async () => { + const provider = new ChainAgnosticProvider(channelMock); + (channelMock.request as jest.Mock).mockResolvedValue('success'); + + await new Promise(process.nextTick); + + (onDomReady as jest.Mock).mock.calls[0][0](); + + const firstCallCallback = jest.fn(); + const secondCallCallback = jest.fn(); + provider + .request({ + data: { method: 'random_method' }, + } as any) + .then(firstCallCallback) + .catch(firstCallCallback); + provider + .request({ + data: { method: 'random_method' }, + } as any) + .then(secondCallCallback) + .catch(secondCallCallback); + + await new Promise(process.nextTick); + expect(firstCallCallback).toHaveBeenCalledWith('success'); + expect(secondCallCallback).toHaveBeenCalledWith('success'); + }); + + it('should call the request of the connection', async () => { + const provider = new ChainAgnosticProvider(channelMock); + (channelMock.request as jest.Mock).mockResolvedValueOnce('success'); + + await new Promise(process.nextTick); + + (onDomReady as jest.Mock).mock.calls[0][0](); + + await provider.request({ + data: { method: 'some-method', params: [{ param1: 1 }] }, + sessionId: '00000000-0000-0000-0000-000000000000', + chainId: '1', + }); + expect(channelMock.request).toHaveBeenCalled(); + }); + describe('CAIP-27', () => { + it('should wrap the incoming request into CAIP-27 envelope and reuses the provided ID', async () => { + const provider = new ChainAgnosticProvider(channelMock); + // response for the actual call + (channelMock.request as jest.Mock).mockResolvedValueOnce('success'); + + await new Promise(process.nextTick); + + (onDomReady as jest.Mock).mock.calls[0][0](); + + provider.request({ + data: { method: 'some-method', params: [{ param1: 1 }] }, + sessionId: '00000000-0000-0000-0000-000000000000', + chainId: '1', + }); + + await new Promise(process.nextTick); + + expect(channelMock.request).toHaveBeenCalledWith({ + jsonrpc: '2.0', + method: 'provider_request', + params: { + scope: 'eip155:1', + sessionId: '00000000-0000-0000-0000-000000000000', + request: { + method: 'some-method', + params: [{ param1: 1 }], + }, + }, + }); + }); + }); + }); +}); diff --git a/src/background/providers/ChainAgnosticProvider.ts b/src/background/providers/ChainAgnosticProvider.ts new file mode 100644 index 000000000..3df1b647b --- /dev/null +++ b/src/background/providers/ChainAgnosticProvider.ts @@ -0,0 +1,113 @@ +import EventEmitter from 'events'; +import { + DAppProviderRequest, + JsonRpcRequest, + JsonRpcRequestPayload, +} from '../connections/dAppConnection/models'; +import { PartialBy } from '../models'; +import { ethErrors, serializeError } from 'eth-rpc-errors'; +import AbstractConnection from '../utils/messaging/AbstractConnection'; +import { ChainId } from '@avalabs/core-chains-sdk'; +import RequestRatelimiter from './utils/RequestRatelimiter'; +import { + InitializationStep, + ProviderReadyPromise, +} from './utils/ProviderReadyPromise'; +import onDomReady from './utils/onDomReady'; +import { getSiteMetadata } from './utils/getSiteMetadata'; + +export class ChainAgnosticProvider extends EventEmitter { + #contentScriptConnection: AbstractConnection; + #providerReadyPromise = new ProviderReadyPromise([ + InitializationStep.DOMAIN_METADATA_SENT, + ]); + + #requestRateLimiter = new RequestRatelimiter([ + 'eth_requestAccounts', + 'avalanche_selectWallet', + ]); + + constructor(connection) { + super(); + connection.connect(); + this.#contentScriptConnection = connection; + this.#init(); + } + + async #init() { + await this.#contentScriptConnection.connect(); + + onDomReady(async () => { + const domainMetadata = await getSiteMetadata(); + + this.#request({ + data: { + method: DAppProviderRequest.DOMAIN_METADATA_METHOD, + params: domainMetadata, + }, + }); + + this.#providerReadyPromise.check(InitializationStep.DOMAIN_METADATA_SENT); + }); + } + + #request = async ({ + data, + sessionId, + chainId, + }: { + data: PartialBy; + sessionId?: string; + chainId?: string | null; + }) => { + if (!data) { + throw ethErrors.rpc.invalidRequest(); + } + + const result = this.#contentScriptConnection + .request({ + method: 'provider_request', + jsonrpc: '2.0', + params: { + scope: `eip155:${ + chainId ? parseInt(chainId) : ChainId.AVALANCHE_MAINNET_ID + }`, + sessionId, + request: { + params: [], + ...data, + }, + }, + } as JsonRpcRequest) + .catch((err) => { + // If the error is already a JsonRPCErorr do not serialize them. + // eth-rpc-errors always wraps errors if they have an unkown error code + // even if the code is valid like 4902 for unrecognized chain ID. + if (!!err.code && Number.isInteger(err.code) && !!err.message) { + throw err; + } + throw serializeError(err); + }); + return result; + }; + + request = async ({ + data, + sessionId, + chainId, + }: { + data: PartialBy; + sessionId?: string; + chainId?: string | null; + }) => { + return this.#providerReadyPromise.call(() => { + return this.#requestRateLimiter.call(data.method, () => + this.#request({ data, chainId, sessionId }) + ); + }); + }; + + subscribeToMessage = (callback) => { + this.#contentScriptConnection.on('message', callback); + }; +} diff --git a/src/background/providers/CoreProvider.test.ts b/src/background/providers/CoreProvider.test.ts index 0d6d8da13..12dc55fb2 100644 --- a/src/background/providers/CoreProvider.test.ts +++ b/src/background/providers/CoreProvider.test.ts @@ -1,8 +1,9 @@ import { ethErrors } from 'eth-rpc-errors'; import { CoreProvider } from './CoreProvider'; -import onDomReady from './utils/onDomReady'; import { DAppProviderRequest } from '../connections/dAppConnection/models'; import AutoPairingPostMessageConnection from '../utils/messaging/AutoPairingPostMessageConnection'; +import { EventNames } from './models'; +import { matchingPayload } from './ChainAgnosticProvider.test'; jest.mock('../utils/messaging/AutoPairingPostMessageConnection', () => { const mocks = { @@ -15,25 +16,18 @@ jest.mock('../utils/messaging/AutoPairingPostMessageConnection', () => { jest.mock('./utils/onDomReady'); -const matchingPayload = (payload) => - expect.objectContaining({ - params: expect.objectContaining({ - request: expect.objectContaining(payload), - }), - }); - +const channelMockResolvedValue = { + isUnlocked: true, + chainId: '0x1', + networkVersion: '1', + accounts: ['0x00000'], +}; describe('src/background/providers/CoreProvider', () => { const channelMock = new AutoPairingPostMessageConnection(false); + const addEventListenerSpy = jest.spyOn(window, 'addEventListener'); beforeEach(() => { jest.mocked(channelMock.connect).mockResolvedValueOnce(undefined); - - (channelMock.request as jest.Mock).mockResolvedValueOnce({ - isUnlocked: true, - chainId: '0x1', - networkVersion: '1', - accounts: ['0x00000'], - }); }); afterEach(() => { @@ -42,7 +36,7 @@ describe('src/background/providers/CoreProvider', () => { describe('EIP-5749', () => { it('sets the ProviderInfo', () => { - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); expect(provider.info).toEqual({ description: 'EVM_PROVIDER_INFO_DESCRIPTION', icon: 'EVM_PROVIDER_INFO_ICON', @@ -53,58 +47,22 @@ describe('src/background/providers/CoreProvider', () => { }); }); - describe('CAIP-27', () => { - let provider; - - beforeEach(async () => { - provider = new CoreProvider({ connection: channelMock }); - - // wait for init to finish - await new Promise(process.nextTick); - - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); - - // domReady to allow requests through - (onDomReady as jest.Mock).mock.calls[0][0](); - - await new Promise(process.nextTick); - }); - - it('wraps incoming requests into CAIP-27 envelope and reuses the provided ID', async () => { - // response for the actual call - (channelMock.request as jest.Mock).mockResolvedValueOnce('success'); - - provider.send( - { - method: 'some-method', - params: [{ param1: 1 }], - }, - jest.fn() - ); - - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledWith({ - jsonrpc: '2.0', - method: 'provider_request', - params: { - scope: 'eip155:1', - sessionId: '00000000-0000-0000-0000-000000000000', - request: { - method: 'some-method', - params: [{ param1: 1 }], - }, - }, - }); - }); - }); - describe('EIP-1193', () => { describe('request', () => { it('collects pending requests till the dom is ready', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -115,10 +73,8 @@ describe('src/background/providers/CoreProvider', () => { }) ); - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); // response for 'some-method' - (channelMock.request as jest.Mock).mockResolvedValueOnce('success'); + (channelMock.request as jest.Mock).mockResolvedValue('success'); const rpcResultCallback = jest.fn(); provider .request({ @@ -128,14 +84,11 @@ describe('src/background/providers/CoreProvider', () => { .then(rpcResultCallback); await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); + expect(channelMock.request).toHaveBeenCalledTimes(2); - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); await new Promise(process.nextTick); - expect(channelMock.request).toHaveBeenCalledTimes(3); + expect(channelMock.request).toHaveBeenCalledTimes(2); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'some-method', @@ -146,64 +99,20 @@ describe('src/background/providers/CoreProvider', () => { expect(rpcResultCallback).toHaveBeenCalledWith('success'); }); - it('rate limits `eth_requestAccounts` requests', async () => { - const provider = new CoreProvider({ connection: channelMock }); - - // wait for init to finish - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(1); - expect(channelMock.request).toHaveBeenCalledWith( - matchingPayload({ - method: DAppProviderRequest.INIT_DAPP_STATE, - }) - ); - - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); - // response for 'eth_requestAccounts' - (channelMock.request as jest.Mock).mockResolvedValue('success'); - const firstCallCallback = jest.fn(); - const secondCallCallback = jest.fn(); - provider - .request({ - method: 'eth_requestAccounts', - }) - .then(firstCallCallback) - .catch(firstCallCallback); - provider - .request({ - method: 'eth_requestAccounts', - }) - .then(secondCallCallback) - .catch(secondCallCallback); - await new Promise(process.nextTick); - - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); - expect(channelMock.request).toHaveBeenCalledWith( - matchingPayload({ - method: 'eth_requestAccounts', - }) - ); - - expect(firstCallCallback).toHaveBeenCalledWith('success'); - expect(secondCallCallback).toHaveBeenCalledWith( - ethErrors.rpc.resourceUnavailable( - `Request of type eth_requestAccounts already pending for origin. Please wait.` - ) - ); - }); - it('always returns JSON RPC-compatible error', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -214,8 +123,6 @@ describe('src/background/providers/CoreProvider', () => { }) ); - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); // response for 'eth_requestAccounts' (channelMock.request as jest.Mock).mockRejectedValueOnce( new Error('non RPC error') @@ -228,32 +135,32 @@ describe('src/background/providers/CoreProvider', () => { .catch(callCallback); await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); await new Promise(process.nextTick); - expect(channelMock.request).toHaveBeenCalledTimes(3); + expect(channelMock.request).toHaveBeenCalledTimes(2); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'eth_requestAccounts', }) ); - expect(callCallback).toHaveBeenCalledWith({ - code: -32603, - data: { - originalError: {}, - }, - message: 'non RPC error', - }); + expect(callCallback).toHaveBeenCalledWith(new Error('non RPC error')); }); it('does not double wraps JSON RPC errors', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -264,8 +171,6 @@ describe('src/background/providers/CoreProvider', () => { }) ); - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); // response for 'eth_requestAccounts' (channelMock.request as jest.Mock).mockRejectedValueOnce({ code: 4902, @@ -281,14 +186,7 @@ describe('src/background/providers/CoreProvider', () => { .catch(callCallback); await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); + expect(channelMock.request).toHaveBeenCalledTimes(2); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'eth_requestAccounts', @@ -306,12 +204,27 @@ describe('src/background/providers/CoreProvider', () => { describe('events', () => { describe(`connect`, () => { - it('emits `connect` when chainId first set', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should emit `connect` when chainId first set', async () => { + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const connectSubscription = jest.fn(); provider.addListener('connect', connectSubscription); - // wait for init to finish + // // wait for init to finish await new Promise(process.nextTick); expect(channelMock.request).toHaveBeenCalledTimes(1); @@ -325,7 +238,7 @@ describe('src/background/providers/CoreProvider', () => { expect(connectSubscription).toHaveBeenCalledWith({ chainId: '0x1' }); }); - it('does not emit connect if chain is still loading', async () => { + it('should not emit connect if chain is still loading', async () => { (channelMock.request as jest.Mock).mockReset(); (channelMock.request as jest.Mock).mockResolvedValue({ isUnlocked: true, @@ -333,7 +246,19 @@ describe('src/background/providers/CoreProvider', () => { networkVersion: 'loading', accounts: ['0x00000'], }); - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const connectSubscription = jest.fn(); provider.addListener('connect', connectSubscription); @@ -359,8 +284,23 @@ describe('src/background/providers/CoreProvider', () => { }); }); - it('emits connect on re-connect after disconnected', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should emit connect on re-connect after disconnected', async () => { + const provider = new CoreProvider(); + (channelMock.request as jest.Mock).mockResolvedValue( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const connectSubscription = jest.fn(); const disconnectSubscription = jest.fn(); provider.addListener('connect', connectSubscription); @@ -395,7 +335,19 @@ describe('src/background/providers/CoreProvider', () => { describe('disconnect', () => { it('emits disconnect event with error', async () => { - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const disconnectSubscription = jest.fn(); provider.addListener('disconnect', disconnectSubscription); @@ -426,8 +378,20 @@ describe('src/background/providers/CoreProvider', () => { }); describe('chainChanged', () => { - it('does not emit `chainChanged` on initialization', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should not emit `chainChanged` on initialization', async () => { + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const chainChangedSubscription = jest.fn(); provider.addListener('chainChanged', chainChangedSubscription); @@ -445,7 +409,19 @@ describe('src/background/providers/CoreProvider', () => { networkVersion: 'loading', accounts: ['0x00000'], }); - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const chainChangedSubscription = jest.fn(); provider.addListener('chainChanged', chainChangedSubscription); @@ -453,8 +429,7 @@ describe('src/background/providers/CoreProvider', () => { await new Promise(process.nextTick); expect(chainChangedSubscription).not.toHaveBeenCalled(); - - (channelMock.on as jest.Mock).mock.calls[0][1]({ + (channelMock.on as jest.Mock).mock.calls[0]?.[1]({ method: 'chainChanged', params: { chainId: '0x1', networkVersion: '1' }, }); @@ -462,8 +437,24 @@ describe('src/background/providers/CoreProvider', () => { expect(chainChangedSubscription).toHaveBeenCalledWith('0x1'); }); - it('does not emit `chainChanged` when chain is set to the same value', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should not emit `chainChanged` when chain is set to the same value', async () => { + const provider = new CoreProvider(); + + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const chainChangedSubscription = jest.fn(); provider.addListener('chainChanged', chainChangedSubscription); @@ -486,7 +477,19 @@ describe('src/background/providers/CoreProvider', () => { }); it('emits `chainChanged` when chain is set to new value', async () => { - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const chainChangedSubscription = jest.fn(); provider.addListener('chainChanged', chainChangedSubscription); @@ -513,7 +516,19 @@ describe('src/background/providers/CoreProvider', () => { describe('accountsChanged', () => { it('emits `accountsChanged` on initialization', async () => { - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const accountsChangedSubscription = jest.fn(); provider.addListener('accountsChanged', accountsChangedSubscription); @@ -534,7 +549,19 @@ describe('src/background/providers/CoreProvider', () => { networkVersion: '1', accounts: undefined, }); - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const accountsChangedSubscription = jest.fn(); provider.addListener('accountsChanged', accountsChangedSubscription); @@ -545,8 +572,24 @@ describe('src/background/providers/CoreProvider', () => { expect(accountsChangedSubscription).toHaveBeenCalledWith([]); }); - it('does not emit `accountsChanged` when account is set to the same value', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should not emit `accountsChanged` when account is set to the same value', async () => { + const provider = new CoreProvider(); + + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const accountsChangedSubscription = jest.fn(); provider.addListener('accountsChanged', accountsChangedSubscription); @@ -563,8 +606,24 @@ describe('src/background/providers/CoreProvider', () => { expect(accountsChangedSubscription).toHaveBeenCalledTimes(1); }); - it('emits `accountsChanged` when account is set to new value', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should emit `accountsChanged` when account is set to new value', async () => { + const provider = new CoreProvider(); + + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const accountsChangedSubscription = jest.fn(); provider.addListener('accountsChanged', accountsChangedSubscription); @@ -588,16 +647,27 @@ describe('src/background/providers/CoreProvider', () => { describe('legacy', () => { describe('sendAsync', () => { - it('collects pending requests till the dom is ready', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + it('should call the requests correctly', async () => { + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce(undefined); // response for 'some-method' (channelMock.request as jest.Mock).mockResolvedValueOnce('success'); + // response for domain metadata send + (channelMock.request as jest.Mock).mockResolvedValueOnce(undefined); const rpcResultCallback = jest.fn(); provider.sendAsync( { @@ -609,13 +679,8 @@ describe('src/background/providers/CoreProvider', () => { await new Promise(process.nextTick); // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); + expect(channelMock.request).toHaveBeenCalledTimes(2); - expect(channelMock.request).toHaveBeenCalledTimes(3); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'some-method', @@ -629,66 +694,20 @@ describe('src/background/providers/CoreProvider', () => { }); }); - it('rate limits `eth_requestAccounts` requests', async () => { - const provider = new CoreProvider({ connection: channelMock }); - - // wait for init to finish - await new Promise(process.nextTick); - - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); - // response for 'eth_requestAccounts' - (channelMock.request as jest.Mock).mockResolvedValue('success'); - const firstCallCallback = jest.fn(); - const secondCallCallback = jest.fn(); - provider.sendAsync( - { - method: 'eth_requestAccounts', - }, - firstCallCallback - ); - provider.sendAsync( - { - method: 'eth_requestAccounts', + it('should support batched requets', async () => { + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, }, - secondCallCallback - ); - await new Promise(process.nextTick); - - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); - expect(channelMock.request).toHaveBeenCalledWith( - matchingPayload({ - method: 'eth_requestAccounts', - }) - ); - - expect(firstCallCallback).toHaveBeenCalledWith(null, { - method: 'eth_requestAccounts', - result: 'success', }); - expect(secondCallCallback).toHaveBeenCalledWith( - ethErrors.rpc.resourceUnavailable( - `Request of type eth_requestAccounts already pending for origin. Please wait.` - ), - { - method: 'eth_requestAccounts', - error: ethErrors.rpc.resourceUnavailable( - `Request of type eth_requestAccounts already pending for origin. Please wait.` - ), - } - ); - }); - - it('supports batched requets', async () => { - const provider = new CoreProvider({ connection: channelMock }); - // wait for init to finish await new Promise(process.nextTick); @@ -708,16 +727,10 @@ describe('src/background/providers/CoreProvider', () => { ], rpcResultCallback ); - await new Promise(process.nextTick); - - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); await new Promise(process.nextTick); - expect(channelMock.request).toHaveBeenCalledTimes(4); + expect(channelMock.request).toHaveBeenCalledTimes(3); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'some-method', @@ -739,14 +752,23 @@ describe('src/background/providers/CoreProvider', () => { }); describe('send', () => { - it('collects pending requests till the dom is ready', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + it('should call the requests properly', async () => { + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); // response for 'some-method' (channelMock.request as jest.Mock).mockResolvedValueOnce('success'); const rpcResultCallback = jest.fn(); @@ -759,14 +781,8 @@ describe('src/background/providers/CoreProvider', () => { ); await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); + expect(channelMock.request).toHaveBeenCalledTimes(2); - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'some-method', @@ -780,66 +796,20 @@ describe('src/background/providers/CoreProvider', () => { }); }); - it('rate limits `eth_requestAccounts` requests', async () => { - const provider = new CoreProvider({ connection: channelMock }); - - // wait for init to finish - await new Promise(process.nextTick); - - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce({}); - // response for 'eth_requestAccounts' - (channelMock.request as jest.Mock).mockResolvedValue('success'); - const firstCallCallback = jest.fn(); - const secondCallCallback = jest.fn(); - provider.send( - { - method: 'eth_requestAccounts', - }, - firstCallCallback - ); - provider.send( - { - method: 'eth_requestAccounts', + it('should support batched requets', async () => { + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, }, - secondCallCallback - ); - await new Promise(process.nextTick); - - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); - expect(channelMock.request).toHaveBeenCalledWith( - matchingPayload({ - method: 'eth_requestAccounts', - }) - ); - - expect(firstCallCallback).toHaveBeenCalledWith(null, { - method: 'eth_requestAccounts', - result: 'success', }); - expect(secondCallCallback).toHaveBeenCalledWith( - ethErrors.rpc.resourceUnavailable( - `Request of type eth_requestAccounts already pending for origin. Please wait.` - ), - { - method: 'eth_requestAccounts', - error: ethErrors.rpc.resourceUnavailable( - `Request of type eth_requestAccounts already pending for origin. Please wait.` - ), - } - ); - }); - - it('supports batched requets', async () => { - const provider = new CoreProvider({ connection: channelMock }); - // wait for init to finish await new Promise(process.nextTick); @@ -859,16 +829,10 @@ describe('src/background/providers/CoreProvider', () => { ], rpcResultCallback ); - await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); await new Promise(process.nextTick); - expect(channelMock.request).toHaveBeenCalledTimes(4); + expect(channelMock.request).toHaveBeenCalledTimes(3); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'some-method', @@ -888,9 +852,20 @@ describe('src/background/providers/CoreProvider', () => { ]); }); - it('supports method as the only param', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + it('should support method as the only param', async () => { + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -904,14 +879,7 @@ describe('src/background/providers/CoreProvider', () => { await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); + expect(channelMock.request).toHaveBeenCalledTimes(2); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'some-method', @@ -925,9 +893,20 @@ describe('src/background/providers/CoreProvider', () => { }); }); - it('supports method with params', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + it('should support method with params', async () => { + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -941,14 +920,7 @@ describe('src/background/providers/CoreProvider', () => { await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); + expect(channelMock.request).toHaveBeenCalledTimes(2); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'some-method', @@ -962,9 +934,25 @@ describe('src/background/providers/CoreProvider', () => { }); }); - it('returns eth_accounts response syncronously', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should return eth_accounts response syncronously', async () => { + const provider = new CoreProvider(); + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -978,9 +966,24 @@ describe('src/background/providers/CoreProvider', () => { }); }); - it('returns eth_coinbase response syncronously', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should return eth_coinbase response syncronously', async () => { + const provider = new CoreProvider(); + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -995,8 +998,19 @@ describe('src/background/providers/CoreProvider', () => { }); it('throws error if method not supported syncronously', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); @@ -1019,28 +1033,34 @@ describe('src/background/providers/CoreProvider', () => { }); describe('enable', () => { - it('collects pending requests till the dom is ready', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should call the requests properly', async () => { + const provider = new CoreProvider(); + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce(undefined); // response for 'some-method' (channelMock.request as jest.Mock).mockResolvedValueOnce(['0x0000']); const rpcResultCallback = jest.fn(); provider.enable().then(rpcResultCallback); await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); + expect(channelMock.request).toHaveBeenCalledTimes(2); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'eth_requestAccounts', @@ -1049,69 +1069,34 @@ describe('src/background/providers/CoreProvider', () => { expect(rpcResultCallback).toHaveBeenCalledWith(['0x0000']); }); - - it('rate limits enable calls', async () => { - const provider = new CoreProvider({ connection: channelMock }); - - // wait for init to finish - await new Promise(process.nextTick); - - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce(undefined); - // response for 'eth_requestAccounts' - (channelMock.request as jest.Mock).mockResolvedValue(['0x0000']); - const firstCallCallback = jest.fn(); - const secondCallCallback = jest.fn(); - provider.enable().then(firstCallCallback).catch(firstCallCallback); - provider.enable().then(secondCallCallback).catch(secondCallCallback); - await new Promise(process.nextTick); - - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); - expect(channelMock.request).toHaveBeenCalledWith( - matchingPayload({ - method: 'eth_requestAccounts', - }) - ); - - expect(firstCallCallback).toHaveBeenCalledWith(['0x0000']); - expect(secondCallCallback).toHaveBeenCalledWith( - ethErrors.rpc.resourceUnavailable( - `Request of type eth_requestAccounts already pending for origin. Please wait.` - ) - ); - }); }); describe('net_version', () => { it('supports net_version call', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + const provider = new CoreProvider(); + + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); // wait for init to finish await new Promise(process.nextTick); - // response for domain metadata send - (channelMock.request as jest.Mock).mockResolvedValueOnce(undefined); // response for 'some-method' (channelMock.request as jest.Mock).mockResolvedValueOnce('1'); const rpcResultCallback = jest.fn(); provider.net_version().then(rpcResultCallback); await new Promise(process.nextTick); - // no domReady happened yet, still only one call sent - expect(channelMock.request).toHaveBeenCalledTimes(1); - - // domReady triggers sending pending requests as well - (onDomReady as jest.Mock).mock.calls[0][0](); - await new Promise(process.nextTick); - - expect(channelMock.request).toHaveBeenCalledTimes(3); + expect(channelMock.request).toHaveBeenCalledTimes(2); expect(channelMock.request).toHaveBeenCalledWith( matchingPayload({ method: 'net_version', @@ -1123,8 +1108,24 @@ describe('src/background/providers/CoreProvider', () => { }); describe('close event', () => { - it('emits close event with error', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should emit close event with error', async () => { + const provider = new CoreProvider(); + + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const closeSubscription = jest.fn(); provider.on('close', closeSubscription); @@ -1155,8 +1156,21 @@ describe('src/background/providers/CoreProvider', () => { }); describe('networkChanged event', () => { - it('does not emit `networkChanged` on initialization', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should not emit `networkChanged` on initialization', async () => { + const provider = new CoreProvider(); + + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const networkChangedSubscription = jest.fn(); provider.addListener('networkChanged', networkChangedSubscription); @@ -1174,7 +1188,19 @@ describe('src/background/providers/CoreProvider', () => { networkVersion: 'loading', accounts: ['0x00000'], }); - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const networkChangedSubscription = jest.fn(); provider.addListener('networkChanged', networkChangedSubscription); @@ -1191,8 +1217,24 @@ describe('src/background/providers/CoreProvider', () => { expect(networkChangedSubscription).toHaveBeenCalledWith('1'); }); - it('does not emit `networkChanged` when chain is set to the same value', async () => { - const provider = new CoreProvider({ connection: channelMock }); + it('should not emit `networkChanged` when chain is set to the same value', async () => { + const provider = new CoreProvider(); + + (channelMock.request as jest.Mock).mockResolvedValueOnce( + channelMockResolvedValue + ); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const networkChangedSubscription = jest.fn(); provider.addListener('networkChanged', networkChangedSubscription); @@ -1215,7 +1257,19 @@ describe('src/background/providers/CoreProvider', () => { }); it('emits `chainChanged` when chain is set to new value', async () => { - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const networkChangedSubscription = jest.fn(); provider.addListener('networkChanged', networkChangedSubscription); @@ -1243,24 +1297,27 @@ describe('src/background/providers/CoreProvider', () => { }); describe('init', () => { - it('waits for message channel to be connected', async () => { - const mockedChannel = new AutoPairingPostMessageConnection(false); - let resolve; - const promise = new Promise((res) => { - resolve = res; + it('should call the event listener with the right event name', async () => { + new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, }); + expect(addEventListenerSpy).toHaveBeenCalledTimes(1); - jest.mocked(mockedChannel.connect).mockReturnValue(promise); - - new CoreProvider({ connection: mockedChannel }); - expect(mockedChannel.connect).toHaveBeenCalled(); - expect(mockedChannel.request).not.toHaveBeenCalled(); - - resolve(); - await new Promise(process.nextTick); - expect(mockedChannel.request).toHaveBeenCalled(); + expect(addEventListenerSpy).toHaveBeenCalledWith( + EventNames.CORE_WALLET_ANNOUNCE_PROVIDER, + expect.any(Function) + ); }); - it('loads provider state from the background', async () => { const mockedChannel = new AutoPairingPostMessageConnection(false); @@ -1271,7 +1328,19 @@ describe('src/background/providers/CoreProvider', () => { networkVersion: '1', accounts: ['0x00000'], }); - const provider = new CoreProvider({ connection: mockedChannel }); + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); const initializedSubscription = jest.fn(); provider.addListener('_initialized', initializedSubscription); await new Promise(process.nextTick); @@ -1298,8 +1367,19 @@ describe('src/background/providers/CoreProvider', () => { describe('Metamask compatibility', () => { it('supports _metamask.isUnlocked', async () => { - const provider = new CoreProvider({ connection: channelMock }); - + const provider = new CoreProvider(); + (addEventListenerSpy.mock.calls[0]?.[1] as any)({ + detail: { + provider: { + subscribeToMessage: jest.fn((callback) => { + return channelMock.on('message', callback); + }), + request: jest.fn((params) => { + return channelMock.request(params as any); + }), + }, + }, + }); expect(await provider._metamask.isUnlocked()).toBe(false); // wait for init to finish @@ -1307,11 +1387,11 @@ describe('src/background/providers/CoreProvider', () => { expect(await provider._metamask.isUnlocked()).toBe(true); }); it('isMetamask is true', () => { - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); expect(provider.isMetaMask).toBe(true); }); it('isAvalanche is true', async () => { - const provider = new CoreProvider({ connection: channelMock }); + const provider = new CoreProvider(); expect(provider.isAvalanche).toBe(true); }); }); diff --git a/src/background/providers/CoreProvider.ts b/src/background/providers/CoreProvider.ts index 93994e69e..4e497da0f 100644 --- a/src/background/providers/CoreProvider.ts +++ b/src/background/providers/CoreProvider.ts @@ -1,12 +1,10 @@ -import { ethErrors, serializeError } from 'eth-rpc-errors'; +import { ethErrors } from 'eth-rpc-errors'; import EventEmitter from 'events'; -import { getSiteMetadata } from './utils/getSiteMetadata'; -import onDomReady from './utils/onDomReady'; -import RequestRatelimiter from './utils/RequestRatelimiter'; import { - AccountsChangedEventData, - ChainChangedEventData, - UnlockStateChangedEventData, + EventNames, + type AccountsChangedEventData, + type ChainChangedEventData, + type UnlockStateChangedEventData, } from './models'; import { InitializationStep, @@ -14,12 +12,10 @@ import { } from './utils/ProviderReadyPromise'; import { DAppProviderRequest, - JsonRpcRequest, - JsonRpcRequestPayload, + type JsonRpcRequestPayload, } from '../connections/dAppConnection/models'; -import { PartialBy, ProviderInfo } from '../models'; -import AbstractConnection from '../utils/messaging/AbstractConnection'; -import { ChainId } from '@avalabs/core-chains-sdk'; +import type { PartialBy, ProviderInfo } from '../models'; +import { ChainAgnosticProvider } from './ChainAgnosticProvider'; interface ProviderState { accounts: string[] | null; @@ -29,12 +25,10 @@ interface ProviderState { } export class CoreProvider extends EventEmitter { - #contentScriptConnection: AbstractConnection; - #requestRateLimiter = new RequestRatelimiter([ - 'eth_requestAccounts', - 'avalanche_selectWallet', + #providerReadyPromise = new ProviderReadyPromise([ + InitializationStep.PROVIDER_STATE_LOADED, ]); - #providerReadyPromise = new ProviderReadyPromise(); + #chainagnosticProvider?: ChainAgnosticProvider; readonly info: ProviderInfo = { name: EVM_PROVIDER_INFO_NAME, @@ -72,41 +66,40 @@ export class CoreProvider extends EventEmitter { isUnlocked: () => Promise.resolve(this._isUnlocked), }; - constructor({ - connection, - maxListeners = 100, - }: { - connection: AbstractConnection; - maxListeners?: number; - }) { + constructor(maxListeners: number = 100) { super(); this.setMaxListeners(maxListeners); - this.#contentScriptConnection = connection; - this.#init(); + this.#subscribe(); } + #subscribe() { + window.addEventListener( + EventNames.CORE_WALLET_ANNOUNCE_PROVIDER, + (event) => { + if (this.#chainagnosticProvider) { + return; + } + this.#chainagnosticProvider = (event).detail.provider; + + this.#chainagnosticProvider?.subscribeToMessage( + this.#handleBackgroundMessage + ); + + this.#init(); + } + ); + + window.dispatchEvent(new Event(EventNames.CORE_WALLET_REQUEST_PROVIDER)); + } /** * Initializes provider state, and collects dApp information */ #init = async () => { - await this.#contentScriptConnection.connect(); - this.#contentScriptConnection.on('message', this.#handleBackgroundMessage); - - onDomReady(async () => { - const domainMetadata = await getSiteMetadata(); - - this.#requestInternal({ - method: DAppProviderRequest.DOMAIN_METADATA_METHOD, - params: domainMetadata, - }); - - this.#providerReadyPromise.check(InitializationStep.DOMAIN_METADATA_SENT); - }); - try { - const response = await this.#requestInternal({ + const response = await this.#request({ method: DAppProviderRequest.INIT_DAPP_STATE, }); + const { chainId, accounts, networkVersion, isUnlocked } = (response as { isUnlocked: boolean; @@ -141,7 +134,6 @@ export class CoreProvider extends EventEmitter { InitializationStep.PROVIDER_STATE_LOADED ); } catch (e) { - console.error(e); // the provider will have a partial state, but still should be able to function } finally { this._initialized = true; @@ -154,40 +146,11 @@ export class CoreProvider extends EventEmitter { #request = async ( data: PartialBy ) => { - if (!data) { - throw ethErrors.rpc.invalidRequest(); - } - - return this.#contentScriptConnection - .request({ - method: 'provider_request', - jsonrpc: '2.0', - params: { - scope: `eip155:${ - this.chainId ? parseInt(this.chainId) : ChainId.AVALANCHE_MAINNET_ID - }`, - sessionId: this._sessionId, - request: { - params: [], - ...data, - }, - }, - } as JsonRpcRequest) - .catch((err) => { - // If the error is already a JsonRPCErorr do not serialize them. - // eth-rpc-errors always wraps errors if they have an unkown error code - // even if the code is valid like 4902 for unrecognized chain ID. - if (!!err.code && Number.isInteger(err.code) && !!err.message) { - throw err; - } - throw serializeError(err); - }); - }; - - #requestInternal = (data) => { - return this.#requestRateLimiter.call(data.method, () => - this.#request(data) - ); + return this.#chainagnosticProvider?.request({ + data, + chainId: this.chainId, + sessionId: this._sessionId, + }); }; #getEventHandler = (method: string): ((params: any) => void) => { @@ -216,9 +179,7 @@ export class CoreProvider extends EventEmitter { request = async (data: PartialBy) => { return this.#providerReadyPromise.call(() => { - return this.#requestRateLimiter.call(data.method, () => - this.#request(data) - ); + return this.#request(data); }); }; diff --git a/src/background/providers/MultiWalletProviderProxy.test.ts b/src/background/providers/MultiWalletProviderProxy.test.ts index e5fc80684..63b8820a4 100644 --- a/src/background/providers/MultiWalletProviderProxy.test.ts +++ b/src/background/providers/MultiWalletProviderProxy.test.ts @@ -4,7 +4,6 @@ import { MultiWalletProviderProxy, createMultiWalletProxy, } from './MultiWalletProviderProxy'; -import AutoPairingPostMessageConnection from '../utils/messaging/AutoPairingPostMessageConnection'; jest.mock('../utils/messaging/AutoPairingPostMessageConnection'); jest.mock('./CoreProvider', () => ({ @@ -16,11 +15,9 @@ jest.mock('./CoreProvider', () => ({ })); describe('src/background/providers/MultiWalletProviderProxy', () => { - const connectionMock = new AutoPairingPostMessageConnection(false); - describe('init', () => { it('initializes with the default provider', () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = new MultiWalletProviderProxy(provider); @@ -46,7 +43,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { describe('addProvider', () => { it('adds new providers from coinbase proxy', () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = new MultiWalletProviderProxy(provider); expect(mwpp.defaultProvider).toBe(provider); @@ -73,7 +70,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('does not add extra coinbase proxy', () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = new MultiWalletProviderProxy(provider); expect(mwpp.defaultProvider).toBe(provider); @@ -90,7 +87,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('adds new provider', () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = new MultiWalletProviderProxy(provider); expect(mwpp.defaultProvider).toBe(provider); @@ -107,7 +104,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { describe('wallet selection', () => { it('toggles wallet selection on `eth_requestAccounts` call if multiple providers', async () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const provider2 = { isMetaMask: true, request: jest.fn() }; const mwpp = new MultiWalletProviderProxy(provider); mwpp.addProvider(provider2); @@ -154,7 +151,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('does not toggle wallet selection if only core is available', async () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = new MultiWalletProviderProxy(provider); provider.request = jest.fn().mockResolvedValueOnce(['0x000000']); @@ -177,7 +174,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('does not toggle wallet selection if wallet is already selected', async () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const provider2 = { isMetaMask: true, request: jest.fn() }; const mwpp = new MultiWalletProviderProxy(provider); mwpp.addProvider(provider2); @@ -233,7 +230,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('wallet selection works with legacy functions: enable', async () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const provider2 = { isMetaMask: true, enable: jest.fn() }; const mwpp = new MultiWalletProviderProxy(provider); mwpp.addProvider(provider2); @@ -275,7 +272,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('wallet selection works with legacy functions: sendAsync', async () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const provider2 = { isMetaMask: true, request: jest.fn() }; const mwpp = new MultiWalletProviderProxy(provider); mwpp.addProvider(provider2); @@ -329,7 +326,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('wallet selection works with legacy functions: send with callback', async () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const provider2 = { isMetaMask: true, request: jest.fn() }; const mwpp = new MultiWalletProviderProxy(provider); mwpp.addProvider(provider2); @@ -381,7 +378,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { describe('createMultiWalletProxy', () => { it('creates proxy with property deletion disabled', () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = createMultiWalletProxy(provider); expect(mwpp.defaultProvider).toBe(provider); @@ -390,7 +387,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('allows setting extra params without changing the provider', () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = createMultiWalletProxy(provider); (mwpp as any).somePromerty = true; @@ -414,7 +411,7 @@ describe('src/background/providers/MultiWalletProviderProxy', () => { }); it('maintains the providers list properly', () => { - const provider = new CoreProvider({ connection: connectionMock }); + const provider = new CoreProvider(); const mwpp = createMultiWalletProxy(provider); const fooMock = () => 'bar'; const bizMock = () => 'baz'; diff --git a/src/background/providers/initializeInpageProvider.test.ts b/src/background/providers/initializeInpageProvider.test.ts index abd66f007..259bfd796 100644 --- a/src/background/providers/initializeInpageProvider.test.ts +++ b/src/background/providers/initializeInpageProvider.test.ts @@ -3,7 +3,15 @@ import { CoreProvider } from './CoreProvider'; import { createMultiWalletProxy } from './MultiWalletProviderProxy'; import { initializeProvider } from './initializeInpageProvider'; -jest.mock('../utils/messaging/AutoPairingPostMessageConnection'); +jest.mock('../utils/messaging/AutoPairingPostMessageConnection', () => { + const mocks = { + connect: jest.fn().mockResolvedValue(undefined), + on: jest.fn(), + request: jest.fn().mockResolvedValue({}), + }; + return jest.fn().mockReturnValue(mocks); +}); + jest.mock('./CoreProvider', () => ({ CoreProvider: jest .fn() @@ -27,10 +35,7 @@ describe('src/background/providers/initializeInpageProvider', () => { it('initializes CoreProvider with the correct channel name', () => { const provider = initializeProvider(connectionMock, 10, windowMock); - expect(CoreProvider).toHaveBeenCalledWith({ - connection: connectionMock, - maxListeners: 10, - }); + expect(CoreProvider).toHaveBeenCalledWith(10); expect(provider.isAvalanche).toBe(true); }); @@ -176,7 +181,7 @@ describe('src/background/providers/initializeInpageProvider', () => { it('announces core provider with eip6963:announceProvider', () => { const provider = initializeProvider(connectionMock, 10, windowMock); - expect(windowMock.dispatchEvent).toHaveBeenCalledTimes(4); + expect(windowMock.dispatchEvent).toHaveBeenCalledTimes(5); expect(windowMock.dispatchEvent.mock.calls[3][0].type).toEqual( 'eip6963:announceProvider' ); @@ -186,11 +191,11 @@ describe('src/background/providers/initializeInpageProvider', () => { }); }); it('re-announces on eip6963:requestProvider', () => { - const provider = initializeProvider(connectionMock, 10, windowMock); + initializeProvider(connectionMock, 10, windowMock); - expect(windowMock.dispatchEvent).toHaveBeenCalledTimes(4); + expect(windowMock.dispatchEvent).toHaveBeenCalledTimes(5); - expect(windowMock.addEventListener).toHaveBeenCalledTimes(1); + expect(windowMock.addEventListener).toHaveBeenCalledTimes(2); expect(windowMock.addEventListener).toHaveBeenCalledWith( 'eip6963:requestProvider', expect.anything() @@ -198,15 +203,27 @@ describe('src/background/providers/initializeInpageProvider', () => { windowMock.addEventListener.mock.calls[0][1](); - expect(windowMock.dispatchEvent).toHaveBeenCalledTimes(5); + expect(windowMock.dispatchEvent).toHaveBeenCalledTimes(6); - expect(windowMock.dispatchEvent.mock.calls[4][0].type).toEqual( + expect(windowMock.dispatchEvent.mock.calls[3][0].type).toEqual( 'eip6963:announceProvider' ); - expect(windowMock.dispatchEvent.mock.calls[4][0].detail).toEqual({ - info: provider.info, - provider: provider, - }); + }); + }); + describe('core-wallet ', () => { + it('should announce chainagnostic provider with core-wallet:announceProvider', () => { + initializeProvider(connectionMock, 10, windowMock); + + expect(windowMock.dispatchEvent.mock.calls[4][0].type).toEqual( + 'core-wallet:announceProvider' + ); + }); + it('should re-announce on core-wallet:requestProvider', () => { + initializeProvider(connectionMock, 10, windowMock); + + expect(windowMock.dispatchEvent.mock.calls[4][0].type).toEqual( + 'core-wallet:announceProvider' + ); }); }); }); diff --git a/src/background/providers/initializeInpageProvider.ts b/src/background/providers/initializeInpageProvider.ts index 03964e0f5..f57466e6f 100644 --- a/src/background/providers/initializeInpageProvider.ts +++ b/src/background/providers/initializeInpageProvider.ts @@ -1,7 +1,8 @@ -import AbstractConnection from '../utils/messaging/AbstractConnection'; +import type AbstractConnection from '../utils/messaging/AbstractConnection'; +import { ChainAgnosticProvider } from './ChainAgnosticProvider'; import { CoreProvider } from './CoreProvider'; import { createMultiWalletProxy } from './MultiWalletProviderProxy'; -import { EIP6963ProviderDetail } from './models'; +import { EventNames, type EIP6963ProviderDetail } from './models'; /** * Initializes a CoreProvide and assigns it as window.ethereum. @@ -16,7 +17,14 @@ export function initializeProvider( maxListeners = 100, globalObject = window ): CoreProvider { - const provider = new Proxy(new CoreProvider({ connection, maxListeners }), { + const chainAgnosticProvider = new Proxy( + new ChainAgnosticProvider(connection), + { + deleteProperty: () => true, + } + ); + + const provider = new Proxy(new CoreProvider(maxListeners), { // some common libraries, e.g. web3@1.x, mess with our API deleteProperty: () => true, }); @@ -25,6 +33,7 @@ export function initializeProvider( setAvalancheGlobalProvider(provider, globalObject); setEvmproviders(provider, globalObject); announceWalletProvider(provider, globalObject); + announceChainAgnosticProvider(chainAgnosticProvider, globalObject); return provider; } @@ -119,7 +128,7 @@ function announceWalletProvider( globalObject = window ): void { const announceEvent = new CustomEvent( - 'eip6963:announceProvider', + EventNames.EIP6963_ANNOUNCE_PROVIDER, { detail: Object.freeze({ info: { ...providerInstance.info }, @@ -134,7 +143,31 @@ function announceWalletProvider( // The Wallet listens to the request events which may be // dispatched later and re-dispatches the `EIP6963AnnounceProviderEvent` - globalObject.addEventListener('eip6963:requestProvider', () => { + globalObject.addEventListener(EventNames.EIP6963_REQUEST_PROVIDER, () => { + globalObject.dispatchEvent(announceEvent); + }); +} + +function announceChainAgnosticProvider( + providerInstance: ChainAgnosticProvider, + globalObject = window +): void { + const announceEvent = new CustomEvent<{ provider: ChainAgnosticProvider }>( + EventNames.CORE_WALLET_ANNOUNCE_PROVIDER, + { + detail: Object.freeze({ + provider: providerInstance, + }), + } + ); + + // The Wallet dispatches an announce event which is heard by + // the DApp code that had run earlier + globalObject.dispatchEvent(announceEvent); + + // The Wallet listens to the request events which may be + // dispatched later and re-dispatches the `EIP6963AnnounceProviderEvent` + globalObject.addEventListener(EventNames.CORE_WALLET_REQUEST_PROVIDER, () => { globalObject.dispatchEvent(announceEvent); }); } diff --git a/src/background/providers/models.ts b/src/background/providers/models.ts index c952c5709..f338116ac 100644 --- a/src/background/providers/models.ts +++ b/src/background/providers/models.ts @@ -23,3 +23,10 @@ export interface EIP6963ProviderDetail { info: EIP6963ProviderInfo; provider: Eip1193Provider; } + +export enum EventNames { + CORE_WALLET_ANNOUNCE_PROVIDER = 'core-wallet:announceProvider', + CORE_WALLET_REQUEST_PROVIDER = 'core-wallet:requestProvider', + EIP6963_ANNOUNCE_PROVIDER = 'eip6963:announceProvider', + EIP6963_REQUEST_PROVIDER = 'eip6963:requestProvider', +} diff --git a/src/background/providers/utils/ProviderReadyPromise.test.ts b/src/background/providers/utils/ProviderReadyPromise.test.ts index d8071e715..e9550557f 100644 --- a/src/background/providers/utils/ProviderReadyPromise.test.ts +++ b/src/background/providers/utils/ProviderReadyPromise.test.ts @@ -5,7 +5,9 @@ import { describe('src/background/providers/utils/ProviderReadyPromise', () => { it('calls immediately if all checks are checked', async () => { - const initializedPromise = new ProviderReadyPromise(); + const initializedPromise = new ProviderReadyPromise([ + InitializationStep.DOMAIN_METADATA_SENT, + ]); initializedPromise.check(InitializationStep.DOMAIN_METADATA_SENT); initializedPromise.check(InitializationStep.PROVIDER_STATE_LOADED); @@ -17,7 +19,10 @@ describe('src/background/providers/utils/ProviderReadyPromise', () => { }); it('calls pending requests when last check is checked', async () => { - const initializedPromise = new ProviderReadyPromise(); + const initializedPromise = new ProviderReadyPromise([ + InitializationStep.DOMAIN_METADATA_SENT, + InitializationStep.PROVIDER_STATE_LOADED, + ]); initializedPromise.check(InitializationStep.DOMAIN_METADATA_SENT); const callMock = jest.fn(); @@ -31,7 +36,10 @@ describe('src/background/providers/utils/ProviderReadyPromise', () => { }); it('suspends calls when a check is unckecked', async () => { - const initializedPromise = new ProviderReadyPromise(); + const initializedPromise = new ProviderReadyPromise([ + InitializationStep.DOMAIN_METADATA_SENT, + InitializationStep.PROVIDER_STATE_LOADED, + ]); initializedPromise.check(InitializationStep.DOMAIN_METADATA_SENT); initializedPromise.check(InitializationStep.PROVIDER_STATE_LOADED); @@ -44,7 +52,6 @@ describe('src/background/providers/utils/ProviderReadyPromise', () => { initializedPromise.call(callMock); expect(callMock).toHaveBeenCalledTimes(1); - initializedPromise.check(InitializationStep.DOMAIN_METADATA_SENT); await new Promise(process.nextTick); diff --git a/src/background/providers/utils/ProviderReadyPromise.ts b/src/background/providers/utils/ProviderReadyPromise.ts index 4db3e72da..e877d73a3 100644 --- a/src/background/providers/utils/ProviderReadyPromise.ts +++ b/src/background/providers/utils/ProviderReadyPromise.ts @@ -1,32 +1,35 @@ export enum InitializationStep { - DOMAIN_METADATA_SENT, - PROVIDER_STATE_LOADED, + DOMAIN_METADATA_SENT = 'domain_metadata_sent', + PROVIDER_STATE_LOADED = 'provider_state_loaded', } export class ProviderReadyPromise { - #steps: boolean[] = []; + #unpreparedSteps: Map = new Map(); #inflightRequests: { resolve(value: unknown): void; - fn(): Promise; + fn(): Promise; }[] = []; - constructor() { - // length / 2 is required since InitializationStep is an enum - // enums generate objects like this: { key0: 0, key1: 1, 0: key0, 1: key1 } - this.#steps = Array(Object.keys(InitializationStep).length / 2).fill(false); + constructor(steps: InitializationStep[]) { + steps.map((step) => this.#unpreparedSteps.set(step, true)); } check = (step: InitializationStep) => { - this.#steps[step] = true; + const hasStep = this.#unpreparedSteps.has(step); + + if (hasStep) { + this.#unpreparedSteps.delete(step); + } + this._proceed(); }; uncheck = (step: InitializationStep) => { - this.#steps[step] = false; + this.#unpreparedSteps.set(step, true); }; private _proceed = () => { - if (this.#steps.some((step) => !step)) { + if (this.#unpreparedSteps.size) { return; } @@ -36,7 +39,7 @@ export class ProviderReadyPromise { } }; - call = (fn) => { + call = (fn: () => Promise) => { return new Promise((resolve) => { this.#inflightRequests.push({ fn,