diff --git a/apps/desktop/src/components/onboarding/final.tsx b/apps/desktop/src/components/onboarding/final.tsx
index 4558e0b368..ab66919c66 100644
--- a/apps/desktop/src/components/onboarding/final.tsx
+++ b/apps/desktop/src/components/onboarding/final.tsx
@@ -14,6 +14,7 @@ import { Route } from "../../routes/app/onboarding/_layout.index";
import * as settings from "../../store/tinybase/store/settings";
import { commands } from "../../types/tauri.gen";
import { configureProSettings } from "../../utils";
+import { pollForTrialActivation } from "../../utils/poll-trial-activation";
import { getBack, type StepProps } from "./config";
import { OnboardingContainer } from "./shared";
@@ -26,6 +27,10 @@ export function Final({ onNavigate }: StepProps) {
const [isLoading, setIsLoading] = useState(true);
const [trialStarted, setTrialStarted] = useState(false);
const hasHandledRef = useRef(false);
+ const authRef = useRef(auth);
+ authRef.current = auth;
+ const storeRef = useRef(store);
+ storeRef.current = store;
const backStep = getBack(search);
@@ -35,25 +40,33 @@ export function Final({ onNavigate }: StepProps) {
}
hasHandledRef.current = true;
+ const abortController = new AbortController();
+
const handle = async () => {
- if (!auth?.session) {
+ const currentAuth = authRef.current;
+ if (!currentAuth?.session) {
setIsLoading(false);
return;
}
- const headers = auth.getHeaders();
+ const headers = currentAuth.getHeaders();
if (!headers) {
setIsLoading(false);
return;
}
try {
- const started = await tryStartTrial(headers, store);
+ const started = await tryStartTrial(headers, storeRef.current);
setTrialStarted(started);
if (started) {
- await new Promise((resolve) => setTimeout(resolve, 3000));
+ const result = await pollForTrialActivation({
+ refreshSession: () => authRef.current.refreshSession(),
+ signal: abortController.signal,
+ });
+ if (result.status === "aborted") return;
+ } else {
+ await authRef.current.refreshSession();
}
- await auth.refreshSession();
} catch (e) {
Sentry.captureException(e);
console.error(e);
@@ -63,7 +76,12 @@ export function Final({ onNavigate }: StepProps) {
};
void handle();
- }, [auth, store]);
+
+ return () => {
+ abortController.abort();
+ };
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, []);
if (isLoading) {
return (
diff --git a/apps/desktop/src/components/settings/general/account.tsx b/apps/desktop/src/components/settings/general/account.tsx
index f85865624d..230b76efca 100644
--- a/apps/desktop/src/components/settings/general/account.tsx
+++ b/apps/desktop/src/components/settings/general/account.tsx
@@ -1,4 +1,4 @@
-import { useMutation, useQuery } from "@tanstack/react-query";
+import { useQuery } from "@tanstack/react-query";
import {
Brain,
Cloud,
@@ -9,7 +9,7 @@ import {
} from "lucide-react";
import { type ReactNode, useCallback, useEffect, useState } from "react";
-import { getRpcCanStartTrial, postBillingStartTrial } from "@hypr/api-client";
+import { getRpcCanStartTrial } from "@hypr/api-client";
import { createClient } from "@hypr/api-client/client";
import { commands as analyticsCommands } from "@hypr/plugin-analytics";
import { type SubscriptionStatus } from "@hypr/plugin-auth";
@@ -22,6 +22,7 @@ import { cn } from "@hypr/utils";
import { useAuth } from "../../../auth";
import { useBillingAccess } from "../../../billing";
import { env } from "../../../env";
+import { useTrialActivation } from "../../../hooks/useTrialActivation";
const WEB_APP_BASE_URL = env.VITE_APP_URL ?? "http://localhost:3000";
@@ -333,41 +334,7 @@ function BillingButton() {
},
});
- const startTrialMutation = useMutation({
- mutationFn: async () => {
- const headers = auth?.getHeaders();
- if (!headers) {
- throw new Error("Not authenticated");
- }
- const client = createClient({ baseUrl: env.VITE_API_URL, headers });
- const { error } = await postBillingStartTrial({
- client,
- query: { interval: "monthly" },
- });
- if (error) {
- throw error;
- }
-
- await new Promise((resolve) => setTimeout(resolve, 3000));
- },
- onSuccess: async () => {
- void analyticsCommands.event({
- event: "trial_started",
- plan: "pro",
- });
- const trialEndDate = new Date();
- trialEndDate.setDate(trialEndDate.getDate() + 14);
- void analyticsCommands.setProperties({
- email: auth?.session?.user.email,
- user_id: auth?.session?.user.id,
- set: {
- plan: "pro",
- trial_end_date: trialEndDate.toISOString(),
- },
- });
- await auth?.refreshSession();
- },
- });
+ const { startTrial, isPending: isTrialPending } = useTrialActivation();
const handleProUpgrade = useCallback(() => {
void analyticsCommands.event({
@@ -401,8 +368,8 @@ function BillingButton() {
return (
diff --git a/apps/desktop/src/hooks/useTrialActivation.ts b/apps/desktop/src/hooks/useTrialActivation.ts
new file mode 100644
index 0000000000..8a250b84bb
--- /dev/null
+++ b/apps/desktop/src/hooks/useTrialActivation.ts
@@ -0,0 +1,94 @@
+import { useMutation } from "@tanstack/react-query";
+import { useCallback, useEffect, useRef } from "react";
+
+import { postBillingStartTrial } from "@hypr/api-client";
+import { createClient } from "@hypr/api-client/client";
+import { commands as analyticsCommands } from "@hypr/plugin-analytics";
+
+import { useAuth } from "../auth";
+import { env } from "../env";
+import {
+ pollForTrialActivation,
+ type PollResult,
+} from "../utils/poll-trial-activation";
+
+type UseTrialActivationOptions = {
+ onActivated?: () => void;
+ onTimeout?: () => void;
+ onError?: (error: unknown) => void;
+};
+
+export function useTrialActivation(options: UseTrialActivationOptions = {}) {
+ const auth = useAuth();
+ const abortControllerRef = useRef(null);
+
+ useEffect(() => {
+ return () => {
+ abortControllerRef.current?.abort();
+ };
+ }, []);
+
+ const mutation = useMutation({
+ mutationFn: async (): Promise => {
+ const headers = auth?.getHeaders();
+ if (!headers) {
+ throw new Error("Not authenticated");
+ }
+
+ const client = createClient({ baseUrl: env.VITE_API_URL, headers });
+ const { error } = await postBillingStartTrial({
+ client,
+ query: { interval: "monthly" },
+ });
+ if (error) {
+ throw error;
+ }
+
+ abortControllerRef.current?.abort();
+ const abortController = new AbortController();
+ abortControllerRef.current = abortController;
+
+ return pollForTrialActivation({
+ refreshSession: () => auth.refreshSession(),
+ signal: abortController.signal,
+ });
+ },
+ onSuccess: (result) => {
+ if (result.status === "activated" || result.status === "timeout") {
+ void analyticsCommands.event({ event: "trial_started", plan: "pro" });
+ const trialEndDate = new Date();
+ trialEndDate.setDate(trialEndDate.getDate() + 14);
+ void analyticsCommands.setProperties({
+ email: auth?.session?.user.email,
+ user_id: auth?.session?.user.id,
+ set: {
+ plan: "pro",
+ trial_end_date: trialEndDate.toISOString(),
+ },
+ });
+ if (result.status === "activated") {
+ options.onActivated?.();
+ } else {
+ options.onTimeout?.();
+ }
+ }
+ },
+ onError: (error) => {
+ options.onError?.(error);
+ },
+ });
+
+ const cancel = useCallback(() => {
+ abortControllerRef.current?.abort();
+ abortControllerRef.current = null;
+ }, []);
+
+ return {
+ startTrial: mutation.mutate,
+ startTrialAsync: mutation.mutateAsync,
+ isPending: mutation.isPending,
+ isError: mutation.isError,
+ error: mutation.error,
+ cancel,
+ };
+}
diff --git a/apps/desktop/src/utils/poll-trial-activation.ts b/apps/desktop/src/utils/poll-trial-activation.ts
new file mode 100644
index 0000000000..d82c9d1289
--- /dev/null
+++ b/apps/desktop/src/utils/poll-trial-activation.ts
@@ -0,0 +1,74 @@
+import type { Session } from "@supabase/supabase-js";
+
+import { commands as authCommands } from "@hypr/plugin-auth";
+
+const INITIAL_DELAY_MS = 1000;
+const MAX_DELAY_MS = 5000;
+const BACKOFF_FACTOR = 1.5;
+const MAX_ATTEMPTS = 10;
+
+export type PollResult =
+ | { status: "activated"; session: Session }
+ | { status: "timeout" }
+ | { status: "aborted" };
+
+type PollOptions = {
+ refreshSession: () => Promise;
+ signal?: AbortSignal;
+};
+
+export async function pollForTrialActivation(
+ options: PollOptions,
+): Promise {
+ let delay = INITIAL_DELAY_MS;
+
+ for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) {
+ if (options.signal?.aborted) {
+ return { status: "aborted" };
+ }
+
+ try {
+ await new Promise((resolve, reject) => {
+ const timer = setTimeout(resolve, delay);
+ if (options.signal) {
+ const onAbort = () => {
+ clearTimeout(timer);
+ reject(new DOMException("Aborted", "AbortError"));
+ };
+ options.signal.addEventListener("abort", onAbort, { once: true });
+ }
+ });
+ } catch (e) {
+ if (e instanceof DOMException && e.name === "AbortError") {
+ return { status: "aborted" };
+ }
+ throw e;
+ }
+
+ if (options.signal?.aborted) {
+ return { status: "aborted" };
+ }
+
+ try {
+ const session = await options.refreshSession();
+ if (session) {
+ const result = await authCommands.decodeClaims(session.access_token);
+ if (result.status === "ok") {
+ const entitlements = result.data.entitlements ?? [];
+ if (entitlements.includes("hyprnote_pro")) {
+ return { status: "activated", session };
+ }
+ }
+ }
+ } catch (error) {
+ console.warn(
+ `Trial activation poll attempt ${attempt + 1} failed:`,
+ error,
+ );
+ }
+
+ delay = Math.min(delay * BACKOFF_FACTOR, MAX_DELAY_MS);
+ }
+
+ return { status: "timeout" };
+}