diff --git a/apps/desktop/src/components/onboarding/final.tsx b/apps/desktop/src/components/onboarding/final.tsx index 4558e0b368..ab66919c66 100644 --- a/apps/desktop/src/components/onboarding/final.tsx +++ b/apps/desktop/src/components/onboarding/final.tsx @@ -14,6 +14,7 @@ import { Route } from "../../routes/app/onboarding/_layout.index"; import * as settings from "../../store/tinybase/store/settings"; import { commands } from "../../types/tauri.gen"; import { configureProSettings } from "../../utils"; +import { pollForTrialActivation } from "../../utils/poll-trial-activation"; import { getBack, type StepProps } from "./config"; import { OnboardingContainer } from "./shared"; @@ -26,6 +27,10 @@ export function Final({ onNavigate }: StepProps) { const [isLoading, setIsLoading] = useState(true); const [trialStarted, setTrialStarted] = useState(false); const hasHandledRef = useRef(false); + const authRef = useRef(auth); + authRef.current = auth; + const storeRef = useRef(store); + storeRef.current = store; const backStep = getBack(search); @@ -35,25 +40,33 @@ export function Final({ onNavigate }: StepProps) { } hasHandledRef.current = true; + const abortController = new AbortController(); + const handle = async () => { - if (!auth?.session) { + const currentAuth = authRef.current; + if (!currentAuth?.session) { setIsLoading(false); return; } - const headers = auth.getHeaders(); + const headers = currentAuth.getHeaders(); if (!headers) { setIsLoading(false); return; } try { - const started = await tryStartTrial(headers, store); + const started = await tryStartTrial(headers, storeRef.current); setTrialStarted(started); if (started) { - await new Promise((resolve) => setTimeout(resolve, 3000)); + const result = await pollForTrialActivation({ + refreshSession: () => authRef.current.refreshSession(), + signal: abortController.signal, + }); + if (result.status === "aborted") return; + } else { + await authRef.current.refreshSession(); } - await auth.refreshSession(); } catch (e) { Sentry.captureException(e); console.error(e); @@ -63,7 +76,12 @@ export function Final({ onNavigate }: StepProps) { }; void handle(); - }, [auth, store]); + + return () => { + abortController.abort(); + }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); if (isLoading) { return ( diff --git a/apps/desktop/src/components/settings/general/account.tsx b/apps/desktop/src/components/settings/general/account.tsx index f85865624d..230b76efca 100644 --- a/apps/desktop/src/components/settings/general/account.tsx +++ b/apps/desktop/src/components/settings/general/account.tsx @@ -1,4 +1,4 @@ -import { useMutation, useQuery } from "@tanstack/react-query"; +import { useQuery } from "@tanstack/react-query"; import { Brain, Cloud, @@ -9,7 +9,7 @@ import { } from "lucide-react"; import { type ReactNode, useCallback, useEffect, useState } from "react"; -import { getRpcCanStartTrial, postBillingStartTrial } from "@hypr/api-client"; +import { getRpcCanStartTrial } from "@hypr/api-client"; import { createClient } from "@hypr/api-client/client"; import { commands as analyticsCommands } from "@hypr/plugin-analytics"; import { type SubscriptionStatus } from "@hypr/plugin-auth"; @@ -22,6 +22,7 @@ import { cn } from "@hypr/utils"; import { useAuth } from "../../../auth"; import { useBillingAccess } from "../../../billing"; import { env } from "../../../env"; +import { useTrialActivation } from "../../../hooks/useTrialActivation"; const WEB_APP_BASE_URL = env.VITE_APP_URL ?? "http://localhost:3000"; @@ -333,41 +334,7 @@ function BillingButton() { }, }); - const startTrialMutation = useMutation({ - mutationFn: async () => { - const headers = auth?.getHeaders(); - if (!headers) { - throw new Error("Not authenticated"); - } - const client = createClient({ baseUrl: env.VITE_API_URL, headers }); - const { error } = await postBillingStartTrial({ - client, - query: { interval: "monthly" }, - }); - if (error) { - throw error; - } - - await new Promise((resolve) => setTimeout(resolve, 3000)); - }, - onSuccess: async () => { - void analyticsCommands.event({ - event: "trial_started", - plan: "pro", - }); - const trialEndDate = new Date(); - trialEndDate.setDate(trialEndDate.getDate() + 14); - void analyticsCommands.setProperties({ - email: auth?.session?.user.email, - user_id: auth?.session?.user.id, - set: { - plan: "pro", - trial_end_date: trialEndDate.toISOString(), - }, - }); - await auth?.refreshSession(); - }, - }); + const { startTrial, isPending: isTrialPending } = useTrialActivation(); const handleProUpgrade = useCallback(() => { void analyticsCommands.event({ @@ -401,8 +368,8 @@ function BillingButton() { return ( diff --git a/apps/desktop/src/hooks/useTrialActivation.ts b/apps/desktop/src/hooks/useTrialActivation.ts new file mode 100644 index 0000000000..8a250b84bb --- /dev/null +++ b/apps/desktop/src/hooks/useTrialActivation.ts @@ -0,0 +1,94 @@ +import { useMutation } from "@tanstack/react-query"; +import { useCallback, useEffect, useRef } from "react"; + +import { postBillingStartTrial } from "@hypr/api-client"; +import { createClient } from "@hypr/api-client/client"; +import { commands as analyticsCommands } from "@hypr/plugin-analytics"; + +import { useAuth } from "../auth"; +import { env } from "../env"; +import { + pollForTrialActivation, + type PollResult, +} from "../utils/poll-trial-activation"; + +type UseTrialActivationOptions = { + onActivated?: () => void; + onTimeout?: () => void; + onError?: (error: unknown) => void; +}; + +export function useTrialActivation(options: UseTrialActivationOptions = {}) { + const auth = useAuth(); + const abortControllerRef = useRef(null); + + useEffect(() => { + return () => { + abortControllerRef.current?.abort(); + }; + }, []); + + const mutation = useMutation({ + mutationFn: async (): Promise => { + const headers = auth?.getHeaders(); + if (!headers) { + throw new Error("Not authenticated"); + } + + const client = createClient({ baseUrl: env.VITE_API_URL, headers }); + const { error } = await postBillingStartTrial({ + client, + query: { interval: "monthly" }, + }); + if (error) { + throw error; + } + + abortControllerRef.current?.abort(); + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + return pollForTrialActivation({ + refreshSession: () => auth.refreshSession(), + signal: abortController.signal, + }); + }, + onSuccess: (result) => { + if (result.status === "activated" || result.status === "timeout") { + void analyticsCommands.event({ event: "trial_started", plan: "pro" }); + const trialEndDate = new Date(); + trialEndDate.setDate(trialEndDate.getDate() + 14); + void analyticsCommands.setProperties({ + email: auth?.session?.user.email, + user_id: auth?.session?.user.id, + set: { + plan: "pro", + trial_end_date: trialEndDate.toISOString(), + }, + }); + if (result.status === "activated") { + options.onActivated?.(); + } else { + options.onTimeout?.(); + } + } + }, + onError: (error) => { + options.onError?.(error); + }, + }); + + const cancel = useCallback(() => { + abortControllerRef.current?.abort(); + abortControllerRef.current = null; + }, []); + + return { + startTrial: mutation.mutate, + startTrialAsync: mutation.mutateAsync, + isPending: mutation.isPending, + isError: mutation.isError, + error: mutation.error, + cancel, + }; +} diff --git a/apps/desktop/src/utils/poll-trial-activation.ts b/apps/desktop/src/utils/poll-trial-activation.ts new file mode 100644 index 0000000000..d82c9d1289 --- /dev/null +++ b/apps/desktop/src/utils/poll-trial-activation.ts @@ -0,0 +1,74 @@ +import type { Session } from "@supabase/supabase-js"; + +import { commands as authCommands } from "@hypr/plugin-auth"; + +const INITIAL_DELAY_MS = 1000; +const MAX_DELAY_MS = 5000; +const BACKOFF_FACTOR = 1.5; +const MAX_ATTEMPTS = 10; + +export type PollResult = + | { status: "activated"; session: Session } + | { status: "timeout" } + | { status: "aborted" }; + +type PollOptions = { + refreshSession: () => Promise; + signal?: AbortSignal; +}; + +export async function pollForTrialActivation( + options: PollOptions, +): Promise { + let delay = INITIAL_DELAY_MS; + + for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { + if (options.signal?.aborted) { + return { status: "aborted" }; + } + + try { + await new Promise((resolve, reject) => { + const timer = setTimeout(resolve, delay); + if (options.signal) { + const onAbort = () => { + clearTimeout(timer); + reject(new DOMException("Aborted", "AbortError")); + }; + options.signal.addEventListener("abort", onAbort, { once: true }); + } + }); + } catch (e) { + if (e instanceof DOMException && e.name === "AbortError") { + return { status: "aborted" }; + } + throw e; + } + + if (options.signal?.aborted) { + return { status: "aborted" }; + } + + try { + const session = await options.refreshSession(); + if (session) { + const result = await authCommands.decodeClaims(session.access_token); + if (result.status === "ok") { + const entitlements = result.data.entitlements ?? []; + if (entitlements.includes("hyprnote_pro")) { + return { status: "activated", session }; + } + } + } + } catch (error) { + console.warn( + `Trial activation poll attempt ${attempt + 1} failed:`, + error, + ); + } + + delay = Math.min(delay * BACKOFF_FACTOR, MAX_DELAY_MS); + } + + return { status: "timeout" }; +}