diff --git a/shared/modules/shield.ts b/shared/modules/shield.ts index ff44fe0a5852..3305f89c2e54 100644 --- a/shared/modules/shield.ts +++ b/shared/modules/shield.ts @@ -1,4 +1,7 @@ -import { Subscription } from '@metamask/subscription-controller'; +import { + RECURRING_INTERVALS, + Subscription, +} from '@metamask/subscription-controller'; import { getIsShieldSubscriptionActive } from '../lib/shield'; export async function getShieldGatewayConfig( @@ -50,3 +53,34 @@ export async function getShieldGatewayConfig( }; } } + +/** + * Calculate the remaining billing cycles for a subscription + * + * @param params + * @param params.currentPeriodEnd - The current period end date. + * @param params.endDate - The end date. + * @param params.interval - The interval. + * @returns The remaining billing cycles. + */ +export function calculateSubscriptionRemainingBillingCycles({ + currentPeriodEnd, + endDate, + interval, +}: { + currentPeriodEnd: Date; + endDate: Date; + interval: (typeof RECURRING_INTERVALS)[keyof typeof RECURRING_INTERVALS]; +}): number { + if (interval === RECURRING_INTERVALS.month) { + const yearDiff = endDate.getFullYear() - currentPeriodEnd.getFullYear(); + const monthDiff = endDate.getMonth() - currentPeriodEnd.getMonth(); + // Assume the period end and endDate have the same day of the month and time + // Current period is inclusive, so we need to add 1 + return yearDiff * 12 + monthDiff + 1; + } + const yearDiff = endDate.getFullYear() - currentPeriodEnd.getFullYear(); + // Assume the period end and endDate have the same month, day of the month and time + // Current period is inclusive, so we need to add 1 + return yearDiff + 1; +} diff --git a/ui/contexts/shield/shield-subscription.tsx b/ui/contexts/shield/shield-subscription.tsx index 5de97330e3c3..50a1c3c0b4c2 100644 --- a/ui/contexts/shield/shield-subscription.tsx +++ b/ui/contexts/shield/shield-subscription.tsx @@ -18,7 +18,9 @@ import { getHasShieldEntryModalShownOnce, getIsActiveShieldSubscription, } from '../../selectors/subscription'; +import { MetaMaskReduxDispatch } from '../../store/store'; import { getIsUnlocked } from '../../ducks/metamask/metamask'; +import { useShieldAddFundTrigger } from './useAddFundTrigger'; export const ShieldSubscriptionContext = React.createContext<{ resetShieldEntryModalShownStatus: () => void; @@ -45,7 +47,7 @@ export const useShieldSubscriptionContext = () => { }; export const ShieldSubscriptionProvider: React.FC = ({ children }) => { - const dispatch = useDispatch(); + const dispatch = useDispatch(); const isBasicFunctionalityEnabled = Boolean( useSelector(getUseExternalServices), ); @@ -65,6 +67,9 @@ export const ShieldSubscriptionProvider: React.FC = ({ children }) => { true, // use USD conversion rate instead of the current currency ); + // watch handle add fund trigger server check subscirption paused because of insufficient funds + useShieldAddFundTrigger(); + /** * Check if the user's balance criteria is met to show the shield entry modal. * Shield entry modal will be shown if: diff --git a/ui/contexts/shield/useAddFundTrigger.ts b/ui/contexts/shield/useAddFundTrigger.ts new file mode 100644 index 000000000000..1b4aaa42ee57 --- /dev/null +++ b/ui/contexts/shield/useAddFundTrigger.ts @@ -0,0 +1,236 @@ +import { useCallback, useEffect, useMemo } from 'react'; +import { useDispatch, useSelector } from 'react-redux'; +import { + CRYPTO_PAYMENT_METHOD_ERRORS, + PAYMENT_TYPES, + PRODUCT_TYPES, + SUBSCRIPTION_STATUSES, + SubscriptionCryptoPaymentMethod, + SubscriptionStatus, +} from '@metamask/subscription-controller'; +import log from 'loglevel'; +import { useTokenBalances as pollAndUpdateEvmBalances } from '../../hooks/useTokenBalances'; +import { + useUserSubscriptionByProduct, + useUserSubscriptions, +} from '../../hooks/subscription/useSubscription'; +import { + getSubscriptions, + updateSubscriptionCryptoPaymentMethod, +} from '../../store/actions'; +import { getSelectedAccount } from '../../selectors'; +import { + useSubscriptionPaymentMethods, + useSubscriptionPricing, + useSubscriptionProductPlans, +} from '../../hooks/subscription/useSubscriptionPricing'; +import { isCryptoPaymentMethod } from '../../pages/settings/transaction-shield-tab/types'; +import { getTokenBalancesEvm } from '../../selectors/assets'; +import { MetaMaskReduxDispatch } from '../../store/store'; +import { calculateSubscriptionRemainingBillingCycles } from '../../../shared/modules/shield'; +import { useThrottle } from '../../hooks/useThrottle'; +import { MINUTE } from '../../../shared/constants/time'; + +const SHIELD_ADD_FUND_TRIGGER_INTERVAL = 5 * MINUTE; + +/** + * Trigger the subscription check after user funding met criteria + * + */ +export const useShieldAddFundTrigger = () => { + const dispatch = useDispatch(); + const { subscriptions } = useUserSubscriptions(); + const shieldSubscription = useUserSubscriptionByProduct( + PRODUCT_TYPES.SHIELD, + subscriptions, + ); + // TODO: update to correct subscription status after implementation + const isSubscriptionPaused = + shieldSubscription && + ( + [ + SUBSCRIPTION_STATUSES.paused, + SUBSCRIPTION_STATUSES.pastDue, + SUBSCRIPTION_STATUSES.unpaid, + ] as SubscriptionStatus[] + ).includes(shieldSubscription.status); + + const { subscriptionPricing } = useSubscriptionPricing(); + const pricingPlans = useSubscriptionProductPlans( + PRODUCT_TYPES.SHIELD, + subscriptionPricing, + ); + const cryptoPaymentMethod = useSubscriptionPaymentMethods( + PAYMENT_TYPES.byCrypto, + subscriptionPricing, + ); + + const cryptoPaymentInfo = shieldSubscription?.paymentMethod as + | SubscriptionCryptoPaymentMethod + | undefined; + const selectedTokenPrice = cryptoPaymentInfo + ? cryptoPaymentMethod?.chains + ?.find( + (chain) => + chain.chainId.toLowerCase() === + cryptoPaymentInfo?.crypto.chainId.toLowerCase(), + ) + ?.tokens.find( + (token) => + token.symbol.toLowerCase() === + cryptoPaymentInfo?.crypto.tokenSymbol.toLowerCase(), + ) + : undefined; + + const selectedProductPrice = useMemo(() => { + return pricingPlans?.find( + (plan) => plan.interval === shieldSubscription?.interval, + ); + }, [pricingPlans, shieldSubscription]); + + const paymentChainIds = useMemo( + () => (cryptoPaymentInfo ? [cryptoPaymentInfo.crypto.chainId] : []), + [cryptoPaymentInfo], + ); + + const selectedAccount = useSelector(getSelectedAccount); + const evmBalances = useSelector((state) => + getTokenBalancesEvm(state, selectedAccount?.address), + ); + + // Poll and update evm balances for payment chains + pollAndUpdateEvmBalances({ chainIds: paymentChainIds }); + // valid token balances for checking + const validTokenBalances = useMemo(() => { + return evmBalances.filter((token) => { + const supportedTokensForChain = + cryptoPaymentInfo?.crypto.chainId === token.chainId; + const isSupportedChain = Boolean(supportedTokensForChain); + if (!isSupportedChain) { + return false; + } + const isSupportedToken = + cryptoPaymentInfo?.crypto.tokenSymbol.toLowerCase() === + token.symbol.toLowerCase(); + if (!isSupportedToken) { + return false; + } + const hasBalance = token.balance && parseFloat(token.balance) > 0; + if (!hasBalance) { + return false; + } + if (!selectedProductPrice || !shieldSubscription?.endDate) { + return false; + } + + const remainingBillingCycles = + calculateSubscriptionRemainingBillingCycles({ + currentPeriodEnd: new Date(shieldSubscription.currentPeriodEnd), + endDate: new Date(shieldSubscription.endDate), + interval: shieldSubscription.interval, + }); + // no need to use BigInt since max unitDecimals are always 2 for price + const remainingFundBalanceNeeded = + (selectedProductPrice.unitAmount / + 10 ** selectedProductPrice.unitDecimals) * + remainingBillingCycles; + + return ( + token.balance && parseFloat(token.balance) >= remainingFundBalanceNeeded + ); + }); + }, [ + evmBalances, + cryptoPaymentInfo, + selectedProductPrice, + shieldSubscription, + ]); + + const hasAvailableSelectedToken = validTokenBalances.length > 0; + + // throttle the hasAvailableSelectedToken to avoid multiple triggers + const { value: hasAvailableSelectedTokenThrottled } = useThrottle({ + value: hasAvailableSelectedToken, + interval: SHIELD_ADD_FUND_TRIGGER_INTERVAL, + }); + + const handleTriggerSubscriptionCheck = useCallback(async () => { + if ( + !shieldSubscription || + !selectedProductPrice || + !hasAvailableSelectedTokenThrottled || + !cryptoPaymentInfo + ) { + return; + } + + try { + // selected token is available, so we can trigger the subscription check + await dispatch( + updateSubscriptionCryptoPaymentMethod({ + subscriptionId: shieldSubscription.id, + paymentType: PAYMENT_TYPES.byCrypto, + recurringInterval: shieldSubscription.interval, + chainId: cryptoPaymentInfo.crypto.chainId, + payerAddress: cryptoPaymentInfo.crypto.payerAddress, + tokenSymbol: cryptoPaymentInfo.crypto.tokenSymbol, + billingCycles: + shieldSubscription.billingCycles ?? + selectedProductPrice?.minBillingCycles, + rawTransaction: undefined, // no raw transaction to trigger server to check for new funded balance + }), + ); + // refetch subscription after trigger subscription check for new status + await dispatch(getSubscriptions()); + } catch (error) { + log.error( + '[useShieldAddFundTrigger] error triggering subscription check', + error, + ); + } + }, [ + dispatch, + shieldSubscription, + selectedProductPrice, + hasAvailableSelectedTokenThrottled, + cryptoPaymentInfo, + ]); + + useEffect(() => { + if ( + !shieldSubscription || + !isSubscriptionPaused || + !subscriptionPricing || + !cryptoPaymentInfo || + !selectedProductPrice + ) { + return; + } + const isInsufficientBalanceError = + cryptoPaymentInfo.crypto.error === + CRYPTO_PAYMENT_METHOD_ERRORS.INSUFFICIENT_BALANCE; + + const isCryptoPayment = isCryptoPaymentMethod( + shieldSubscription.paymentMethod, + ); + if ( + !isInsufficientBalanceError || + !isCryptoPayment || + !selectedTokenPrice || + !hasAvailableSelectedTokenThrottled + ) { + return; + } + + handleTriggerSubscriptionCheck(); + }, [ + isSubscriptionPaused, + subscriptionPricing, + cryptoPaymentInfo, + selectedTokenPrice, + selectedProductPrice, + hasAvailableSelectedTokenThrottled, + shieldSubscription, + handleTriggerSubscriptionCheck, + ]); +}; diff --git a/ui/hooks/subscription/useSubscriptionPricing.ts b/ui/hooks/subscription/useSubscriptionPricing.ts index 18d6d9b4352e..09bc21706afb 100644 --- a/ui/hooks/subscription/useSubscriptionPricing.ts +++ b/ui/hooks/subscription/useSubscriptionPricing.ts @@ -43,6 +43,15 @@ export type TokenWithApprovalAmount = ( }; }; +/** + * get user available token balances for starting subscription + * + * @param params + * @param params.paymentChains - The payment chains info. + * @param params.price - The product price. + * @param params.productType - The product type. + * @returns The available token balances. + */ export const useAvailableTokenBalances = (params: { paymentChains?: ChainPaymentInfo[]; price?: ProductPrice; diff --git a/ui/hooks/useThrottle.test.ts b/ui/hooks/useThrottle.test.ts new file mode 100644 index 000000000000..e4ee8d1e9482 --- /dev/null +++ b/ui/hooks/useThrottle.test.ts @@ -0,0 +1,180 @@ +import { renderHook } from '@testing-library/react-hooks'; +import { act } from '@testing-library/react'; +import { useThrottle } from './useThrottle'; + +describe('useThrottle', () => { + beforeEach(() => { + // Mock timers for testing + jest.useFakeTimers(); + jest.clearAllTimers(); + }); + + afterEach(() => { + jest.runOnlyPendingTimers(); + jest.useRealTimers(); + }); + + it('should return the initial value immediately', () => { + const { result } = renderHook(() => useThrottle('initial', 1000)); + expect(result.current).toBe('initial'); + }); + + it('should throttle value updates', () => { + const { result, rerender } = renderHook( + ({ value, limit }: { value: string; limit: number }) => + useThrottle(value, limit), + { + initialProps: { value: 'initial', limit: 1000 }, + }, + ); + + // Change value immediately + rerender({ value: 'updated', limit: 1000 }); + expect(result.current).toBe('initial'); // Should still be initial + + // Fast forward time by 500ms (less than limit) + act(() => { + jest.advanceTimersByTime(500); + }); + expect(result.current).toBe('initial'); // Should still be initial + + // Fast forward time by another 500ms (total 1000ms) + act(() => { + jest.advanceTimersByTime(500); + }); + expect(result.current).toBe('updated'); // Should now be updated + }); + + it('should handle multiple rapid value changes', () => { + const { result, rerender } = renderHook( + ({ value, limit }: { value: string; limit: number }) => + useThrottle(value, limit), + { + initialProps: { value: 'initial', limit: 1000 }, + }, + ); + + // Make multiple rapid changes + rerender({ value: 'first', limit: 1000 }); + rerender({ value: 'second', limit: 1000 }); + rerender({ value: 'third', limit: 1000 }); + + expect(result.current).toBe('initial'); // Should still be initial + + // Advance time by 1000ms + act(() => { + jest.advanceTimersByTime(1000); + }); + + expect(result.current).toBe('third'); // Should be the last value + }); + + it('should work with different data types', () => { + const { result, rerender } = renderHook( + ({ value, limit }: { value: number; limit: number }) => + useThrottle(value, limit), + { + initialProps: { value: 0, limit: 500 }, + }, + ); + + rerender({ value: 42, limit: 500 }); + expect(result.current).toBe(0); + + act(() => { + jest.advanceTimersByTime(500); + }); + expect(result.current).toBe(42); + }); + + it('should work with objects', () => { + const initialObj = { name: 'initial' }; + const updatedObj = { name: 'updated' }; + + const { result, rerender } = renderHook( + ({ value, limit }: { value: { name: string }; limit: number }) => + useThrottle(value, limit), + { + initialProps: { value: initialObj, limit: 1000 }, + }, + ); + + rerender({ value: updatedObj, limit: 1000 }); + expect(result.current).toBe(initialObj); + + act(() => { + jest.advanceTimersByTime(1000); + }); + expect(result.current).toBe(updatedObj); + }); + + it('should handle zero limit', () => { + const { result, rerender } = renderHook( + ({ value, limit }: { value: string; limit: number }) => + useThrottle(value, limit), + { + initialProps: { value: 'initial', limit: 0 }, + }, + ); + + rerender({ value: 'updated', limit: 0 }); + + act(() => { + jest.advanceTimersByTime(0); + }); + + expect(result.current).toBe('updated'); + }); + + it('should clean up timeouts on unmount', () => { + const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout'); + + const { unmount } = renderHook(() => useThrottle('test', 1000)); + + unmount(); + + expect(clearTimeoutSpy).toHaveBeenCalled(); + clearTimeoutSpy.mockRestore(); + }); + + it('should handle changing limit values', () => { + const { result, rerender } = renderHook( + ({ value, limit }: { value: string; limit: number }) => + useThrottle(value, limit), + { + initialProps: { value: 'initial', limit: 2000 }, + }, + ); + + // Change value with long limit + rerender({ value: 'updated', limit: 2000 }); + expect(result.current).toBe('initial'); + + // Change to shorter limit + rerender({ value: 'updated', limit: 500 }); + + act(() => { + jest.advanceTimersByTime(600); + }); + + expect(result.current).toBe('updated'); + }); + + it('should handle null and undefined values', () => { + const { result, rerender } = renderHook( + ({ value, limit }: { value: null | undefined; limit: number }) => + useThrottle(value, limit), + { + initialProps: { value: null as null | undefined, limit: 1000 }, + }, + ); + + rerender({ value: undefined, limit: 1000 }); + expect(result.current).toBe(null); + + act(() => { + jest.advanceTimersByTime(1000); + }); + expect(result.current).toBe(undefined); + }); +}); diff --git a/ui/hooks/useThrottle.ts b/ui/hooks/useThrottle.ts new file mode 100644 index 000000000000..ce139d6dd752 --- /dev/null +++ b/ui/hooks/useThrottle.ts @@ -0,0 +1,26 @@ +import { useState, useEffect, useRef } from 'react'; + +export function useThrottle(value: ValueType, interval = 500) { + const [throttledValue, setThrottledValue] = useState(value); + const lastUpdated = useRef(null); + + useEffect(() => { + const now = Date.now(); + + if (lastUpdated.current && now >= lastUpdated.current + interval) { + lastUpdated.current = now; + setThrottledValue(value); + return; + } + + const id = window.setTimeout(() => { + lastUpdated.current = now; + setThrottledValue(value); + }, interval); + + // eslint-disable-next-line consistent-return + return () => window.clearTimeout(id); + }, [value, interval]); + + return throttledValue; +}