diff --git a/apps/mesh/package.json b/apps/mesh/package.json index fb8729be0e..92117f21d0 100644 --- a/apps/mesh/package.json +++ b/apps/mesh/package.json @@ -45,8 +45,8 @@ "quickjs-emscripten-core": "^0.31.0" }, "devDependencies": { - "@ai-sdk/provider": "^3.0.0", - "@ai-sdk/react": "^3.0.1", + "@ai-sdk/provider": "^3.0.8", + "@ai-sdk/react": "^3.0.103", "@better-auth/sso": "1.4.1", "@daveyplate/better-auth-ui": "^3.2.7", "@deco/ui": "workspace:*", @@ -100,7 +100,7 @@ "@untitledui/icons": "^0.0.19", "@vercel/nft": "^1.1.1", "@vitejs/plugin-react": "^5.1.0", - "ai": "^6.0.1", + "ai": "^6.0.101", "babel-plugin-react-compiler": "^1.0.0", "better-auth": "1.4.5", "class-variance-authority": "^0.7.1", diff --git a/apps/mesh/src/api/app.ts b/apps/mesh/src/api/app.ts index 643412516a..4a2affe8fd 100644 --- a/apps/mesh/src/api/app.ts +++ b/apps/mesh/src/api/app.ts @@ -31,7 +31,7 @@ import { tracingMiddleware, } from "../observability"; import authRoutes from "./routes/auth"; -import decopilotRoutes from "./routes/decopilot"; +import { createDecopilotRoutes } from "./routes/decopilot"; import downstreamTokenRoutes from "./routes/downstream-token"; import virtualMcpRoutes from "./routes/virtual-mcp"; import oauthProxyRoutes, { @@ -49,10 +49,26 @@ import { runPluginStartupHooks, } from "../core/plugin-loader"; import { CredentialVault } from "../encryption/credential-vault"; +import { + LocalCancelBroadcast, + type CancelBroadcast, +} from "./routes/decopilot/cancel-broadcast"; +import { createNatsConnectionProvider } from "../nats/connection"; +import { NatsCancelBroadcast } from "./routes/decopilot/nats-cancel-broadcast"; +import { + NoOpStreamBuffer, + type StreamBuffer, +} from "./routes/decopilot/stream-buffer"; +import { NatsStreamBuffer } from "./routes/decopilot/nats-stream-buffer"; +import { RunRegistry } from "./routes/decopilot/run-registry"; +import { SqlThreadStorage } from "../storage/threads"; // Track current event bus instance for cleanup during HMR let currentEventBus: EventBus | null = null; +// Track decopilot strategy cleanup (abort active runs, stop strategies) during HMR +let currentDecopilotCleanup: (() => void) | null = null; + // ============================================================================ // Deco Store OAuth Helpers // ============================================================================ @@ -163,6 +179,21 @@ export async function createApp(options: CreateAppOptions = {}) { }); } + // Create shared NATS provider when NATS_URL is set (must init before event bus) + const natsUrl = process.env.NATS_URL; + let natsProvider = natsUrl ? createNatsConnectionProvider() : null; + if (natsProvider) { + try { + await natsProvider.init(natsUrl!); + } catch (err) { + console.warn( + "[NATS] Connection failed, falling back to local-only mode:", + err, + ); + natsProvider = null; + } + } + // Create event bus with a lazy context getter // The notify function needs a context, but the context needs the event bus // We resolve this by having notify create its own system context @@ -180,12 +211,51 @@ export async function createApp(options: CreateAppOptions = {}) { // Create notify function that uses the context factory // This is called by the worker to deliver events to subscribers // EventBus uses the full MeshDatabase (includes Pool for PostgreSQL) - eventBus = createEventBus(database); + eventBus = createEventBus(database, undefined, natsProvider); } // Track for cleanup during HMR currentEventBus = eventBus; + // Decopilot strategy cleanup on HMR / shutdown + if (currentDecopilotCleanup) currentDecopilotCleanup(); + const threadStorage = new SqlThreadStorage(database.db); + + const runRegistry = new RunRegistry(); + + const cancelBroadcast: CancelBroadcast = natsProvider + ? new NatsCancelBroadcast({ + getConnection: () => natsProvider!.getConnection(), + }) + : new LocalCancelBroadcast(); + + const streamBuffer: StreamBuffer = natsProvider + ? new NatsStreamBuffer({ + getConnection: () => natsProvider!.getConnection(), + getJetStream: () => natsProvider!.getJetStream(), + }) + : new NoOpStreamBuffer(); + + cancelBroadcast + .start((threadId) => runRegistry.cancelLocal(threadId)) + .catch((err) => { + console.error("[Decopilot] CancelBroadcast start failed:", err); + }); + streamBuffer.init().catch((err) => { + console.warn( + "[Decopilot] StreamBuffer init failed, attach/late-join disabled:", + err, + ); + }); + + currentDecopilotCleanup = () => { + runRegistry.stopAll(threadStorage); + runRegistry.dispose(); + cancelBroadcast.stop().catch(() => {}); + streamBuffer.teardown(); + natsProvider?.drain().catch(() => {}); + }; + const app = new Hono(); // ============================================================================ @@ -633,6 +703,11 @@ export async function createApp(options: CreateAppOptions = {}) { } }); + const decopilotRoutes = createDecopilotRoutes({ + cancelBroadcast, + streamBuffer, + runRegistry, + }); app.route("/api", decopilotRoutes); // OpenAI-compatible LLM API routes diff --git a/apps/mesh/src/api/routes/decopilot.ts b/apps/mesh/src/api/routes/decopilot.ts index 1fbc16b112..32a5650d02 100644 --- a/apps/mesh/src/api/routes/decopilot.ts +++ b/apps/mesh/src/api/routes/decopilot.ts @@ -5,5 +5,6 @@ * The actual implementation lives in ./decopilot/routes.ts */ -export { default } from "./decopilot/routes"; +export { createDecopilotRoutes } from "./decopilot/routes"; +export type { DecopilotDeps } from "./decopilot/routes"; export type { StreamRequest } from "./decopilot/schemas"; diff --git a/apps/mesh/src/api/routes/decopilot/built-in-tools/subtask.ts b/apps/mesh/src/api/routes/decopilot/built-in-tools/subtask.ts index 743acc1b2c..ee5671686d 100644 --- a/apps/mesh/src/api/routes/decopilot/built-in-tools/subtask.ts +++ b/apps/mesh/src/api/routes/decopilot/built-in-tools/subtask.ts @@ -171,6 +171,10 @@ export function createSubtaskTool( providerMetadata, }); }, + onAbort: () => { + console.error(`[subtask:${agent_id}] Aborted`); + mcpClient.close().catch(() => {}); + }, onError: (error) => { console.error(`[subtask:${agent_id}] Error`, error); }, diff --git a/apps/mesh/src/api/routes/decopilot/cancel-broadcast.test.ts b/apps/mesh/src/api/routes/decopilot/cancel-broadcast.test.ts new file mode 100644 index 0000000000..8d3e28b0ad --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/cancel-broadcast.test.ts @@ -0,0 +1,41 @@ +import { describe, it, expect } from "bun:test"; +import { LocalCancelBroadcast } from "./cancel-broadcast"; + +describe("LocalCancelBroadcast", () => { + it("start stores the onCancel callback", async () => { + const broadcast = new LocalCancelBroadcast(); + const cancelled: string[] = []; + + await broadcast.start((id) => cancelled.push(id)); + broadcast.broadcast("thread-1"); + + expect(cancelled).toEqual(["thread-1"]); + }); + + it("broadcast invokes callback for each call", async () => { + const broadcast = new LocalCancelBroadcast(); + const cancelled: string[] = []; + + await broadcast.start((id) => cancelled.push(id)); + broadcast.broadcast("a"); + broadcast.broadcast("b"); + + expect(cancelled).toEqual(["a", "b"]); + }); + + it("stop nulls the callback so broadcast is a no-op", async () => { + const broadcast = new LocalCancelBroadcast(); + const cancelled: string[] = []; + + await broadcast.start((id) => cancelled.push(id)); + await broadcast.stop(); + broadcast.broadcast("thread-1"); + + expect(cancelled).toHaveLength(0); + }); + + it("broadcast before start is a no-op (no throw)", () => { + const broadcast = new LocalCancelBroadcast(); + expect(() => broadcast.broadcast("thread-1")).not.toThrow(); + }); +}); diff --git a/apps/mesh/src/api/routes/decopilot/cancel-broadcast.ts b/apps/mesh/src/api/routes/decopilot/cancel-broadcast.ts new file mode 100644 index 0000000000..6bcc3c713f --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/cancel-broadcast.ts @@ -0,0 +1,38 @@ +/** + * Cancel Broadcast Interface + * + * Abstraction for how run cancellation is broadcast across pods. + * In single-process mode, cancel is local only. + * In multi-pod deployments, NATS pub/sub propagates cancellation. + * + * Mirrors the SSEBroadcastStrategy pattern from event-bus. + */ + +export interface CancelBroadcast { + /** Start listening for cancel broadcasts. When received, call onCancel locally. */ + start(onCancel: (threadId: string) => void): Promise; + /** Broadcast a cancellation to all pods (including local). */ + broadcast(threadId: string): void; + /** Stop listening and release resources. */ + stop(): Promise; +} + +/** + * Local-only cancel — cancel only works on the current process. + * Suitable for single-process deployments and when NATS is unavailable. + */ +export class LocalCancelBroadcast implements CancelBroadcast { + private onCancel: ((threadId: string) => void) | null = null; + + async start(onCancel: (threadId: string) => void): Promise { + this.onCancel = onCancel; + } + + broadcast(threadId: string): void { + this.onCancel?.(threadId); + } + + async stop(): Promise { + this.onCancel = null; + } +} diff --git a/apps/mesh/src/api/routes/decopilot/conversation.test.ts b/apps/mesh/src/api/routes/decopilot/conversation.test.ts index 19c0552796..de4f6b7a59 100644 --- a/apps/mesh/src/api/routes/decopilot/conversation.test.ts +++ b/apps/mesh/src/api/routes/decopilot/conversation.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect } from "bun:test"; -import { processConversation } from "./conversation"; +import { processConversation, denyPendingApprovals } from "./conversation"; import type { ChatMessage } from "./types"; describe("processConversation", () => { @@ -103,3 +103,176 @@ describe("processConversation", () => { }); }); }); + +describe("denyPendingApprovals", () => { + it("returns messages unchanged when no assistant messages have pending approvals", () => { + const messages: ChatMessage[] = [ + { + id: "m1", + role: "assistant", + parts: [{ type: "text", text: "Hello" }], + }, + ]; + + const result = denyPendingApprovals(messages); + expect(result).toEqual(messages); + }); + + it("returns non-assistant messages unchanged", () => { + const messages: ChatMessage[] = [ + { id: "m1", role: "user", parts: [{ type: "text", text: "Hi" }] }, + ]; + + const result = denyPendingApprovals(messages); + expect(result).toEqual(messages); + expect(result[0]).toBe(messages[0]); + }); + + it("converts approval-requested state to output-denied with approved: false", () => { + const messages = [ + { + id: "m1", + role: "assistant", + parts: [ + { + type: "tool-invocation", + toolCallId: "tc-1", + toolName: "do_thing", + state: "approval-requested", + approval: { type: "tool-call" }, + args: {}, + }, + ], + }, + ] as unknown as ChatMessage[]; + + const result = denyPendingApprovals(messages); + const part = result[0]!.parts[0] as Record; + + expect(part.state).toBe("output-denied"); + expect(part.approval).toEqual({ + type: "tool-call", + approved: false, + reason: "User sent a new message without approving this tool call.", + }); + }); + + it("leaves parts without approval field unchanged even if state is approval-requested", () => { + const messages = [ + { + id: "m1", + role: "assistant", + parts: [ + { + type: "tool-invocation", + toolCallId: "tc-1", + toolName: "do_thing", + state: "approval-requested", + args: {}, + }, + ], + }, + ] as unknown as ChatMessage[]; + + const result = denyPendingApprovals(messages); + const part = result[0]!.parts[0] as Record; + + expect(part.state).toBe("approval-requested"); + }); + + it("handles mixed parts (some pending, some already resolved)", () => { + const messages = [ + { + id: "m1", + role: "assistant", + parts: [ + { type: "text", text: "Let me do that" }, + { + type: "tool-invocation", + toolCallId: "tc-1", + toolName: "a", + state: "approval-requested", + approval: { type: "tool-call" }, + args: {}, + }, + { + type: "tool-invocation", + toolCallId: "tc-2", + toolName: "b", + state: "output-available", + args: {}, + output: { result: "ok" }, + }, + ], + }, + ] as unknown as ChatMessage[]; + + const result = denyPendingApprovals(messages); + const parts = result[0]!.parts as Record[]; + + expect((parts[0] as { text: string }).text).toBe("Let me do that"); + expect(parts[1]!.state).toBe("output-denied"); + expect((parts[1]!.approval as { approved: boolean }).approved).toBe(false); + expect(parts[2]!.state).toBe("output-available"); + }); + + it("denies pending approvals across multiple assistant messages", () => { + const messages = [ + { + id: "m1", + role: "assistant", + parts: [ + { + type: "tool-invocation", + toolCallId: "tc-1", + toolName: "older_tool", + state: "approval-requested", + approval: { type: "tool-call" }, + args: {}, + }, + ], + }, + { id: "m2", role: "user", parts: [{ type: "text", text: "continue" }] }, + { + id: "m3", + role: "assistant", + parts: [ + { + type: "tool-invocation", + toolCallId: "tc-2", + toolName: "newer_tool", + state: "approval-requested", + approval: { type: "tool-call" }, + args: {}, + }, + ], + }, + ] as unknown as ChatMessage[]; + + const result = denyPendingApprovals(messages); + + const olderPart = result[0]!.parts[0] as Record; + expect(olderPart.state).toBe("output-denied"); + expect((olderPart.approval as { approved: boolean }).approved).toBe(false); + + expect(result[1]).toBe(messages[1]); + + const newerPart = result[2]!.parts[0] as Record; + expect(newerPart.state).toBe("output-denied"); + expect((newerPart.approval as { approved: boolean }).approved).toBe(false); + }); + + it("returns same reference when no assistant messages need patching", () => { + const messages: ChatMessage[] = [ + { id: "m1", role: "user", parts: [{ type: "text", text: "Hi" }] }, + { + id: "m2", + role: "assistant", + parts: [{ type: "text", text: "Hello!" }], + }, + ]; + + const result = denyPendingApprovals(messages); + expect(result).toBe(messages); + }); +}); diff --git a/apps/mesh/src/api/routes/decopilot/conversation.ts b/apps/mesh/src/api/routes/decopilot/conversation.ts index 910209518d..f5f6d4bbba 100644 --- a/apps/mesh/src/api/routes/decopilot/conversation.ts +++ b/apps/mesh/src/api/routes/decopilot/conversation.ts @@ -36,14 +36,9 @@ export interface ProcessedConversation { originalMessages: ChatMessage[]; } -/** - * Marks any tool parts still in "approval-requested" state as "output-denied". - * This happens when the user sends a new message without approving/rejecting - * pending tool calls. convertToModelMessages then produces the correct - * assistant(tool-call) → tool(tool-result) pairing automatically. - */ -function denyPendingApprovals(messages: ChatMessage[]): ChatMessage[] { - return messages.map((msg) => { +export function denyPendingApprovals(messages: ChatMessage[]): ChatMessage[] { + let patched = false; + const result = messages.map((msg) => { if (msg.role !== "assistant") return msg; const hasPending = msg.parts.some( @@ -51,6 +46,7 @@ function denyPendingApprovals(messages: ChatMessage[]): ChatMessage[] { ); if (!hasPending) return msg; + patched = true; return { ...msg, parts: msg.parts.map((part) => { @@ -74,6 +70,8 @@ function denyPendingApprovals(messages: ChatMessage[]): ChatMessage[] { }), } as ChatMessage; }); + + return patched ? result : messages; } function splitMessages(messages: ModelMessage[]): { diff --git a/apps/mesh/src/api/routes/decopilot/helpers.ts b/apps/mesh/src/api/routes/decopilot/helpers.ts index 5320804f23..ac11e183ce 100644 --- a/apps/mesh/src/api/routes/decopilot/helpers.ts +++ b/apps/mesh/src/api/routes/decopilot/helpers.ts @@ -18,6 +18,7 @@ import { import type { Context } from "hono"; import type { MeshContext, OrganizationScope } from "@/core/mesh-context"; +import { HTTPException } from "hono/http-exception"; import { MCP_TOOL_CALL_TIMEOUT_MS } from "../proxy"; import { estimateJsonTokens } from "./built-in-tools/read-tool-output"; @@ -166,3 +167,30 @@ export async function toolsFromMCP( return Object.fromEntries(toolEntries); } + +/** + * Validate that the caller owns the thread and it belongs to the org. + * Reusable across cancel, attach, and other thread-scoped endpoints. + */ +export async function validateThreadOwnership( + c: Context<{ Variables: { meshContext: MeshContext } }>, +) { + const ctx = c.get("meshContext"); + const userId = ctx.auth?.user?.id; + if (!userId) { + throw new HTTPException(401, { message: "Unauthorized" }); + } + const organization = ensureOrganization(c); + const threadId = c.req.param("threadId"); + if (/[.*>\s]/.test(threadId)) { + throw new HTTPException(400, { message: "Invalid thread ID" }); + } + const thread = await ctx.storage.threads.get(threadId); + if (!thread || thread.organization_id !== organization.id) { + throw new HTTPException(404, { message: "Thread not found" }); + } + if (thread.created_by !== userId) { + throw new HTTPException(403, { message: "Not authorized" }); + } + return { ctx, organization, thread, threadId, userId }; +} diff --git a/apps/mesh/src/api/routes/decopilot/nats-cancel-broadcast.test.ts b/apps/mesh/src/api/routes/decopilot/nats-cancel-broadcast.test.ts new file mode 100644 index 0000000000..557f9c260b --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/nats-cancel-broadcast.test.ts @@ -0,0 +1,115 @@ +import { describe, it, expect, mock } from "bun:test"; +import { NatsCancelBroadcast } from "./nats-cancel-broadcast"; + +function createMockSubscription(messages: Array<{ data: Uint8Array }> = []) { + let unsubscribed = false; + return { + unsubscribe() { + unsubscribed = true; + }, + get isUnsubscribed() { + return unsubscribed; + }, + async *[Symbol.asyncIterator]() { + for (const msg of messages) { + if (unsubscribed) return; + yield msg; + } + }, + }; +} + +function createMockNatsConnection( + sub?: ReturnType, +) { + const published: Array<{ subject: string; data: Uint8Array }> = []; + return { + nc: { + subscribe: mock(() => sub ?? createMockSubscription()), + publish(subject: string, data: Uint8Array) { + published.push({ subject, data }); + }, + } as never, + published, + }; +} + +describe("NatsCancelBroadcast", () => { + it("start subscribes to cancel subject", async () => { + const { nc } = createMockNatsConnection(); + const broadcast = new NatsCancelBroadcast({ getConnection: () => nc }); + + await broadcast.start(() => {}); + // @ts-expect-error - nc.subscribe is not typed correctly + expect(nc.subscribe).toHaveBeenCalledTimes(1); + await broadcast.stop(); + }); + + it("broadcast calls local onCancel and publishes to NATS", async () => { + const { nc, published } = createMockNatsConnection(); + const broadcast = new NatsCancelBroadcast({ getConnection: () => nc }); + const cancelled: string[] = []; + + await broadcast.start((id) => cancelled.push(id)); + broadcast.broadcast("thread-1"); + + expect(cancelled).toEqual(["thread-1"]); + expect(published).toHaveLength(1); + const payload = JSON.parse( + new TextDecoder().decode(published[0]?.data ?? new Uint8Array()), + ); + expect(payload.threadId).toBe("thread-1"); + }); + + it("stop unsubscribes and nulls callback", async () => { + const sub = createMockSubscription(); + const { nc } = createMockNatsConnection(sub); + const broadcast = new NatsCancelBroadcast({ getConnection: () => nc }); + + await broadcast.start(() => {}); + await broadcast.stop(); + + expect(sub.isUnsubscribed).toBe(true); + }); + + it("subscription handler invokes onCancel for valid messages", async () => { + const encoder = new TextEncoder(); + const msg = { data: encoder.encode(JSON.stringify({ threadId: "t-abc" })) }; + const sub = createMockSubscription([msg]); + const { nc } = createMockNatsConnection(sub); + const broadcast = new NatsCancelBroadcast({ getConnection: () => nc }); + const cancelled: string[] = []; + + await broadcast.start((id) => cancelled.push(id)); + // Allow async iteration to process + await new Promise((r) => setTimeout(r, 50)); + await broadcast.stop(); + + expect(cancelled).toContain("t-abc"); + }); + + it("subscription handler ignores malformed messages", async () => { + const encoder = new TextEncoder(); + const msg = { data: encoder.encode("not json") }; + const sub = createMockSubscription([msg]); + const { nc } = createMockNatsConnection(sub); + const broadcast = new NatsCancelBroadcast({ getConnection: () => nc }); + const cancelled: string[] = []; + + await broadcast.start((id) => cancelled.push(id)); + await new Promise((r) => setTimeout(r, 50)); + await broadcast.stop(); + + expect(cancelled).toHaveLength(0); + }); + + it("broadcast is a no-op when NATS is unavailable", async () => { + const broadcast = new NatsCancelBroadcast({ getConnection: () => null }); + const cancelled: string[] = []; + + await broadcast.start((id) => cancelled.push(id)); + broadcast.broadcast("thread-1"); + + expect(cancelled).toEqual(["thread-1"]); + }); +}); diff --git a/apps/mesh/src/api/routes/decopilot/nats-cancel-broadcast.ts b/apps/mesh/src/api/routes/decopilot/nats-cancel-broadcast.ts new file mode 100644 index 0000000000..b719c7d9dc --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/nats-cancel-broadcast.ts @@ -0,0 +1,87 @@ +/** + * NATS Cancel Broadcast + * + * Broadcasts run cancellation across pods via NATS Core pub/sub. + * When a cancel is received from any pod, the local onCancel callback + * is invoked to abort the run if it exists on this pod. + * + * Cancel is inherently fire-and-forget — if the pod is gone, the run is gone. + * JetStream persistence would be wrong here (replaying stale cancels). + */ + +import type { NatsConnection, Subscription } from "nats"; +import type { CancelBroadcast } from "./cancel-broadcast"; + +const CANCEL_SUBJECT = "mesh.decopilot.cancel"; + +export interface NatsCancelBroadcastOptions { + getConnection: () => NatsConnection | null; +} + +export class NatsCancelBroadcast implements CancelBroadcast { + private sub: Subscription | null = null; + private onCancel: ((threadId: string) => void) | null = null; + private readonly encoder = new TextEncoder(); + private readonly originId = crypto.randomUUID(); + + constructor(private readonly options: NatsCancelBroadcastOptions) {} + + async start(onCancel: (threadId: string) => void): Promise { + this.onCancel = onCancel; + + const nc = this.options.getConnection(); + if (!nc || this.sub) return; + this.sub = nc.subscribe(CANCEL_SUBJECT); + + const decoder = new TextDecoder(); + + (async () => { + for await (const msg of this.sub!) { + try { + const parsed = JSON.parse(decoder.decode(msg.data)) as { + threadId: string; + originId?: string; + }; + if (parsed.originId === this.originId) continue; + this.onCancel?.(parsed.threadId); + } catch { + // Ignore malformed messages + } + } + })().catch(console.error); + + console.log("[NatsCancelBroadcast] Started, subscribed to", CANCEL_SUBJECT); + } + + broadcast(threadId: string): void { + if (/[.*>\s]/.test(threadId)) { + console.warn( + "[NatsCancelBroadcast] Invalid threadId, skipping broadcast", + ); + return; + } + + this.onCancel?.(threadId); + + const nc = this.options.getConnection(); + if (!nc) return; + + try { + nc.publish( + CANCEL_SUBJECT, + this.encoder.encode( + JSON.stringify({ threadId, originId: this.originId }), + ), + ); + } catch (err) { + console.warn("[NatsCancelBroadcast] Publish failed (non-critical):", err); + } + } + + async stop(): Promise { + this.sub?.unsubscribe(); + this.sub = null; + this.onCancel = null; + console.log("[NatsCancelBroadcast] Stopped"); + } +} diff --git a/apps/mesh/src/api/routes/decopilot/nats-stream-buffer.test.ts b/apps/mesh/src/api/routes/decopilot/nats-stream-buffer.test.ts new file mode 100644 index 0000000000..1fa7d5cbc9 --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/nats-stream-buffer.test.ts @@ -0,0 +1,137 @@ +import { describe, it, expect, mock } from "bun:test"; +import { NatsStreamBuffer } from "./nats-stream-buffer"; +import { NoOpStreamBuffer } from "./stream-buffer"; + +describe("NoOpStreamBuffer", () => { + it("relay returns the input stream unchanged", () => { + const buffer = new NoOpStreamBuffer(); + const stream = new ReadableStream(); + expect(buffer.relay(stream)).toBe(stream); + }); + + it("createReplayStream returns null", async () => { + const buffer = new NoOpStreamBuffer(); + expect(await buffer.createReplayStream()).toBeNull(); + }); + + it("purge and teardown are no-ops (no throw)", () => { + const buffer = new NoOpStreamBuffer(); + expect(() => buffer.purge()).not.toThrow(); + expect(() => buffer.teardown()).not.toThrow(); + }); +}); + +describe("NatsStreamBuffer", () => { + it("init is a no-op when getConnection returns null", async () => { + const buffer = new NatsStreamBuffer({ + getConnection: () => null, + getJetStream: () => null, + }); + await expect(buffer.init()).resolves.toBeUndefined(); + }); + + it("relay passes through when JetStream is unavailable", async () => { + const buffer = new NatsStreamBuffer({ + getConnection: () => null, + getJetStream: () => null, + }); + + const chunks = [{ type: "text", value: "hello" }]; + const input = new ReadableStream({ + start(controller) { + for (const c of chunks) controller.enqueue(c); + controller.close(); + }, + }); + + const output = buffer.relay(input, "thread-1"); + const reader = output.getReader(); + const result = await reader.read(); + + expect(result.value).toEqual(chunks[0]); + }); + + it("createReplayStream returns null when JetStream is unavailable", async () => { + const buffer = new NatsStreamBuffer({ + getConnection: () => null, + getJetStream: () => null, + }); + expect(await buffer.createReplayStream("any")).toBeNull(); + }); + + it("purge is a no-op when jsm is not initialized (no throw)", () => { + const buffer = new NatsStreamBuffer({ + getConnection: () => null, + getJetStream: () => null, + }); + expect(() => buffer.purge("any")).not.toThrow(); + }); + + it("teardown clears references", () => { + const buffer = new NatsStreamBuffer({ + getConnection: () => null, + getJetStream: () => null, + }); + expect(() => buffer.teardown()).not.toThrow(); + }); + + it("init creates or updates stream when connection is available", async () => { + const streamInfoMock = mock(() => Promise.resolve({})); + const streamUpdateMock = mock(() => Promise.resolve({})); + const streamAddMock = mock(() => Promise.resolve({})); + + const mockJsm = { + streams: { + info: streamInfoMock, + update: streamUpdateMock, + add: streamAddMock, + }, + }; + + const mockNc = { + jetstreamManager: mock(() => Promise.resolve(mockJsm)), + }; + + const mockJs = {} as never; + + const buffer = new NatsStreamBuffer({ + getConnection: () => mockNc as never, + getJetStream: () => mockJs, + }); + + await buffer.init(); + + expect(mockNc.jetstreamManager).toHaveBeenCalledTimes(1); + expect(streamInfoMock).toHaveBeenCalledWith("DECOPILOT_STREAMS"); + expect(streamUpdateMock).toHaveBeenCalledTimes(1); + }); + + it("init falls back to add when info throws", async () => { + const streamInfoMock = mock(() => + Promise.reject(new Error("stream not found")), + ); + const streamUpdateMock = mock(() => Promise.resolve({})); + const streamAddMock = mock(() => Promise.resolve({})); + + const mockJsm = { + streams: { + info: streamInfoMock, + update: streamUpdateMock, + add: streamAddMock, + }, + }; + + const mockNc = { + jetstreamManager: mock(() => Promise.resolve(mockJsm)), + }; + + const buffer = new NatsStreamBuffer({ + getConnection: () => mockNc as never, + getJetStream: () => ({}) as never, + }); + + await buffer.init(); + + expect(streamAddMock).toHaveBeenCalledTimes(1); + }); +}); diff --git a/apps/mesh/src/api/routes/decopilot/nats-stream-buffer.ts b/apps/mesh/src/api/routes/decopilot/nats-stream-buffer.ts new file mode 100644 index 0000000000..9c6355521c --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/nats-stream-buffer.ts @@ -0,0 +1,237 @@ +/** + * NATS JetStream Stream Buffer + * + * Publishes UIMessageStream chunks to NATS JetStream (memory storage) + * so late-joining clients can replay the stream from any pod. + * + * Enhancements over original jetstream-relay.ts: + * - Per-subject message limit (20K chunks per thread) prevents one thread from starving others + * - Per-thread publish error tracking with sampled logging + * - Explicit purge method for run completion cleanup + */ + +import { + AckPolicy, + DiscardPolicy, + RetentionPolicy, + StorageType, + type JetStreamClient, + type JetStreamManager, + type NatsConnection, +} from "nats"; +import type { StreamBuffer } from "./stream-buffer"; + +const STREAM_NAME = "DECOPILOT_STREAMS"; +const SUBJECT_PREFIX = "decopilot.stream"; +const MAX_AGE_NS = 5 * 60 * 1_000_000_000; // 5 min +const MAX_BYTES = 500 * 1024 * 1024; // 500 MB +const MAX_MSGS_PER_SUBJECT = 20_000; // ~20K chunks per thread +const PULL_TIMEOUT_MS = 30_000; + +function assertSafeSubjectToken(id: string): void { + if (/[.*>\s]/.test(id)) throw new Error("Invalid NATS subject token"); +} + +function streamSubject(threadId: string): string { + assertSafeSubjectToken(threadId); + return `${SUBJECT_PREFIX}.${threadId}`; +} + +function createPublishTracker(threadId: string) { + let errors = 0; + return { + publish(js: JetStreamClient, subj: string, data: Uint8Array): void { + js.publish(subj, data).catch((err) => { + errors++; + if (errors === 1 || errors % 100 === 0) { + console.warn( + `[Decopilot] JetStream publish failed for thread ${threadId} (${errors} total):`, + err, + ); + } + }); + }, + get errorCount() { + return errors; + }, + }; +} + +export interface NatsStreamBufferOptions { + getConnection: () => NatsConnection | null; + getJetStream: () => JetStreamClient | null; +} + +export class NatsStreamBuffer implements StreamBuffer { + private js: JetStreamClient | null = null; + private jsm: JetStreamManager | null = null; + private readonly encoder = new TextEncoder(); + + constructor(private readonly options: NatsStreamBufferOptions) {} + + async init(): Promise { + const nc = this.options.getConnection(); + if (!nc) return; + + const jsm = await nc.jetstreamManager(); + + const config = { + name: STREAM_NAME, + subjects: [`${SUBJECT_PREFIX}.>`], + storage: StorageType.Memory, + max_age: MAX_AGE_NS, + max_bytes: MAX_BYTES, + max_msgs_per_subject: MAX_MSGS_PER_SUBJECT, + discard: DiscardPolicy.Old, + retention: RetentionPolicy.Limits, + num_replicas: 1, + }; + + try { + await jsm.streams.info(STREAM_NAME); + await jsm.streams.update(STREAM_NAME, config); + } catch (err: unknown) { + const isNotFound = + err instanceof Error && err.message.includes("stream not found"); + if (isNotFound) { + await jsm.streams.add(config); + } else { + throw err; + } + } + + this.js = this.options.getJetStream(); + this.jsm = jsm; + console.log( + "[Decopilot] JetStream stream buffer ready (memory storage, 5min TTL, 20K msgs/subject)", + ); + } + + relay( + stream: ReadableStream, + threadId: string, + abortSignal?: AbortSignal, + ): ReadableStream { + const js = this.js; + if (!js) return stream; + + const subj = streamSubject(threadId); + const tracker = createPublishTracker(threadId); + const encoder = this.encoder; + let terminated = false; + + const publishDone = () => { + if (terminated) return; + terminated = true; + js.publish(subj, encoder.encode(JSON.stringify({ done: true }))).catch( + () => {}, + ); + }; + + abortSignal?.addEventListener("abort", publishDone); + + return stream.pipeThrough( + new TransformStream({ + transform(chunk, controller) { + controller.enqueue(chunk); + tracker.publish( + js, + subj, + encoder.encode(JSON.stringify({ p: chunk })), + ); + }, + flush() { + abortSignal?.removeEventListener("abort", publishDone); + publishDone(); + }, + }), + ); + } + + async createReplayStream(threadId: string): Promise { + const js = this.js; + if (!js) return null; + + const subj = streamSubject(threadId); + + let sub; + try { + sub = await js.subscribe(subj, { + ordered: true, + config: { + filter_subject: subj, + ack_policy: AckPolicy.None, + }, + }); + } catch (err) { + console.warn( + "[Decopilot] JetStream replay unavailable (non-critical):", + (err as Error)?.message ?? err, + ); + return null; + } + + const decoder = new TextDecoder(); + + // Use explicit iterator so pull() maintains position across invocations + const iter = (async function* () { + for await (const msg of sub) { + yield msg; + } + })(); + + return new ReadableStream({ + async pull(controller) { + while (true) { + let timer: ReturnType | undefined; + const result = await Promise.race([ + iter.next(), + new Promise<{ done: true; value: undefined }>((r) => { + timer = setTimeout( + () => r({ done: true, value: undefined }), + PULL_TIMEOUT_MS, + ); + }), + ]); + clearTimeout(timer); + if (result.done) { + sub.unsubscribe(); + controller.close(); + return; + } + const msg = result.value; + try { + const data = JSON.parse(decoder.decode(msg.data)); + if (data.done) { + sub.unsubscribe(); + controller.close(); + return; + } + if (data.p) { + controller.enqueue(data.p); + return; + } + } catch { + // skip malformed, continue to next message + } + } + }, + cancel() { + sub.unsubscribe(); + }, + }); + } + + purge(threadId: string): void { + if (!this.jsm) return; + this.jsm.streams + .purge(STREAM_NAME, { filter: streamSubject(threadId) }) + .catch(() => {}); + } + + teardown(): void { + this.js = null; + this.jsm = null; + console.log("[Decopilot] JetStream stream buffer torn down"); + } +} diff --git a/apps/mesh/src/api/routes/decopilot/routes.ts b/apps/mesh/src/api/routes/decopilot/routes.ts index b5620914f8..f98229526a 100644 --- a/apps/mesh/src/api/routes/decopilot/routes.ts +++ b/apps/mesh/src/api/routes/decopilot/routes.ts @@ -8,7 +8,13 @@ import type { MeshContext } from "@/core/mesh-context"; import { clientFromConnection, withStreamingSupport } from "@/mcp-clients"; import { createVirtualClientFrom } from "@/mcp-clients/virtual-mcp"; -import { sanitizeProviderMetadata } from "@decocms/mesh-sdk"; +import { + sanitizeProviderMetadata, + createDecopilotStepEvent, + createDecopilotFinishEvent, + createDecopilotThreadStatusEvent, + type ThreadStatus, +} from "@decocms/mesh-sdk"; import { consumeStream, createUIMessageStream, @@ -33,9 +39,17 @@ import { processConversation, splitRequestMessages, } from "./conversation"; -import { ensureOrganization, toolsFromMCP } from "./helpers"; -import { createMemory, Memory } from "./memory"; +import { + ensureOrganization, + toolsFromMCP, + validateThreadOwnership, +} from "./helpers"; +import { createMemory } from "./memory"; import { ensureModelCompatibility } from "./model-compat"; +import { sseHub } from "@/event-bus"; +import type { CancelBroadcast } from "./cancel-broadcast"; +import type { StreamBuffer } from "./stream-buffer"; +import type { RunRegistry } from "./run-registry"; import { checkModelPermission, fetchModelPermissions, @@ -46,6 +60,7 @@ import { StreamRequestSchema } from "./schemas"; import { resolveThreadStatus } from "./status"; import { genTitle } from "./title-generator"; import type { ChatMessage } from "./types"; +import { ThreadMessage } from "@/storage/types"; // ============================================================================ // Request Validation @@ -78,290 +93,352 @@ async function validateRequest( // Route Handler // ============================================================================ -const app = new Hono<{ Variables: { meshContext: MeshContext } }>(); +export interface DecopilotDeps { + cancelBroadcast: CancelBroadcast; + streamBuffer: StreamBuffer; + runRegistry: RunRegistry; +} -// ============================================================================ -// Allowed Models Endpoint -// ============================================================================ +export function createDecopilotRoutes(deps: DecopilotDeps) { + const { cancelBroadcast, streamBuffer, runRegistry } = deps; + const app = new Hono<{ Variables: { meshContext: MeshContext } }>(); + + // ============================================================================ + // Allowed Models Endpoint + // ============================================================================ -app.get("/:org/decopilot/allowed-models", async (c) => { - try { - const ctx = c.get("meshContext"); - const organization = ensureOrganization(c); - const role = ctx.auth.user?.role; + app.get("/:org/decopilot/allowed-models", async (c) => { + try { + const ctx = c.get("meshContext"); + const organization = ensureOrganization(c); + const role = ctx.auth.user?.role; - const models = await fetchModelPermissions(ctx.db, organization.id, role); + const models = await fetchModelPermissions(ctx.db, organization.id, role); - return c.json(parseModelsToMap(models)); - } catch (err) { - console.error("[decopilot:allowed-models] Error", err); - if (err instanceof HTTPException) { - return c.json({ error: err.message }, err.status); + return c.json(parseModelsToMap(models)); + } catch (err) { + console.error("[decopilot:allowed-models] Error", err); + if (err instanceof HTTPException) { + return c.json({ error: err.message }, err.status); + } + return c.json( + { error: err instanceof Error ? err.message : "Internal error" }, + 500, + ); } - return c.json( - { error: err instanceof Error ? err.message : "Internal error" }, - 500, - ); - } -}); + }); -// ============================================================================ -// Stream Endpoint -// ============================================================================ + // ============================================================================ + // Stream Endpoint + // ============================================================================ -app.post("/:org/decopilot/stream", async (c) => { - let memory: Memory | undefined; - try { - const ctx = c.get("meshContext"); - - // 1. Validate request - const { - organization, - models, - agent, - systemMessages, - requestMessage, - temperature, - memory: memoryConfig, - thread_id, - toolApprovalLevel, - } = await validateRequest(c); - - const userId = ctx.auth?.user?.id; - if (!userId) { - throw new HTTPException(401, { message: "User ID is required" }); - } + app.post("/:org/decopilot/stream", async (c) => { + let failThread: (() => void) | undefined; + let closeClients: (() => void) | undefined; + try { + const ctx = c.get("meshContext"); - // 2. Check model permissions - const allowedModels = await fetchModelPermissions( - ctx.db, - organization.id, - ctx.auth.user?.role, - ); + // 1. Validate request + const { + organization, + models, + agent, + systemMessages, + requestMessage, + temperature, + memory: memoryConfig, + thread_id, + toolApprovalLevel, + } = await validateRequest(c); + + const userId = ctx.auth?.user?.id; + if (!userId) { + throw new HTTPException(401, { message: "User ID is required" }); + } - if ( - !checkModelPermission( - allowedModels, - models.connectionId, - models.thinking.id, - ) - ) { - throw new HTTPException(403, { - message: "Model not allowed for your role", - }); - } + // 2. Check model permissions + const allowedModels = await fetchModelPermissions( + ctx.db, + organization.id, + ctx.auth.user?.role, + ); + + if ( + !checkModelPermission( + allowedModels, + models.connectionId, + models.thinking.id, + ) + ) { + throw new HTTPException(403, { + message: "Model not allowed for your role", + }); + } - const windowSize = memoryConfig?.windowSize ?? DEFAULT_WINDOW_SIZE; - const resolvedThreadId = thread_id ?? memoryConfig?.thread_id; - - // Get connection entities and create/load memory in parallel - const [virtualMcp, modelConnection, mem] = await Promise.all([ - ctx.storage.virtualMcps.findById(agent.id, organization.id), - ctx.storage.connections.findById(models.connectionId, organization.id), - createMemory(ctx.storage.threads, { - organization_id: organization.id, - thread_id: resolvedThreadId, - userId, - defaultWindowSize: windowSize, - }), - ]); - memory = mem; - - if (!modelConnection) { - throw new Error("Model connection not found"); - } + const windowSize = memoryConfig?.windowSize ?? DEFAULT_WINDOW_SIZE; + const resolvedThreadId = thread_id ?? memoryConfig?.thread_id; + + // Get connection entities and create/load memory in parallel + const [virtualMcp, modelConnection, mem] = await Promise.all([ + ctx.storage.virtualMcps.findById(agent.id, organization.id), + ctx.storage.connections.findById(models.connectionId, organization.id), + createMemory(ctx.storage.threads, { + organization_id: organization.id, + thread_id: resolvedThreadId, + userId, + defaultWindowSize: windowSize, + }), + ]); + const saveMessagesToThread = async ( + ...messages: (typeof requestMessage | undefined)[] + ) => { + const now = Date.now(); + const messagesToSave = [ + ...new Map(messages.filter(Boolean).map((m) => [m!.id, m!])).values(), + ].map((message, i) => ({ + ...message, + thread_id: mem.thread.id, + created_at: new Date(now + i).toISOString(), + updated_at: new Date(now + i).toISOString(), + })); + if (messagesToSave.length === 0) return; + await mem.save(messagesToSave as ThreadMessage[]).catch((error) => { + console.error("[decopilot:stream] Error saving messages", error); + }); + }; - if (!virtualMcp) { - throw new Error("Agent not found"); - } + const completeThread = (status: ThreadStatus) => { + ctx.storage.threads.update(mem.thread.id, { status }).catch((error) => { + console.error( + "[decopilot:stream] Error updating thread status", + error, + ); + }); + const runStatus = status === "completed" ? "completed" : "failed"; + runRegistry.finishRun(mem.thread.id, runStatus, (id) => + streamBuffer.purge(id), + ); + sseHub.emit( + organization.id, + createDecopilotThreadStatusEvent(mem.thread.id, status), + ); + sseHub.emit( + organization.id, + createDecopilotFinishEvent(mem.thread.id, status), + ); + }; + + failThread = () => completeThread("failed"); - // Mark thread as in_progress at the start of streaming - await ctx.storage.threads.update(mem.thread.id, { - status: "in_progress", - }); - - // Always create a passthrough client (all real tools) + model client. - // If mode is smart_tool_selection or code_execution, also create the strategy - // client so we get the gateway meta-tools (SEARCH/DESCRIBE/CALL_TOOL/RUN_CODE). - const isGatewayMode = agent.mode !== "passthrough"; - const [modelClient, passthroughClient, strategyClient] = await Promise.all([ - clientFromConnection(modelConnection, ctx, false), - createVirtualClientFrom(virtualMcp, ctx, "passthrough"), - isGatewayMode - ? createVirtualClientFrom(virtualMcp, ctx, agent.mode) - : Promise.resolve(null), - ]); - - // Add streaming support since agents may use streaming models - const streamableModelClient = withStreamingSupport( - modelClient, - models.connectionId, - modelConnection, - ctx, - { superUser: false }, - ); - - // Extract model provider (can stay outside execute) - const modelProvider = await createModelProviderFromClient( - streamableModelClient, - models, - ); - - // CRITICAL: Register abort handler to ensure client cleanup on disconnect - // Without this, when client disconnects mid-stream, onFinish/onError are NOT called - // and the MCP client + transport streams leak (TextDecoderStream, 256KB buffers) - const abortSignal = c.req.raw.signal; - abortSignal.addEventListener("abort", () => { - modelClient.close().catch(() => {}); - passthroughClient.close().catch(() => {}); - strategyClient?.close().catch(() => {}); - if (mem.thread.id) { - ctx.storage.threads - .update(mem.thread.id, { status: "failed" }) - .catch(() => {}); + if (!modelConnection) { + throw new Error("Model connection not found"); } - }); - - // Get server instructions if available (for virtual MCP agents) - const serverInstructions = passthroughClient.getInstructions(); - - // Merge platform instructions with request system messages - const systemPrompt = DECOPILOT_BASE_PROMPT(serverInstructions); - const allSystemMessages: ChatMessage[] = [systemPrompt, ...systemMessages]; - - const maxOutputTokens = - models.thinking.limits?.maxOutputTokens ?? DEFAULT_MAX_TOKENS; - - let streamFinished = false; - - const allMessages = await loadAndMergeMessages( - mem, - requestMessage, - allSystemMessages, - windowSize, - ); - - const toolOutputMap = new Map(); - // 4. Create stream with writer access for data parts - // IMPORTANT: Do NOT pass onFinish/onStepFinish to createUIMessageStream when - // using writer.merge with toUIMessageStream that has originalMessages. - // createUIMessageStream wraps its stream in handleUIMessageStreamFinish which - // runs processUIMessageStream on every chunk. Without originalMessages, the outer - // state starts with an empty assistant message, causing "No tool invocation found" - // errors when tool-output-available chunks arrive (e.g. after tool approval flow). - const uiStream = createUIMessageStream({ - execute: async ({ writer }) => { - // Create tools inside execute so they have access to writer - // Always get the full passthrough tools (all real tools from connections) - const passthroughTools = await toolsFromMCP( - passthroughClient, - toolOutputMap, - writer, - toolApprovalLevel, - ); - // If using a gateway mode, also get the strategy meta-tools - // (GATEWAY_SEARCH_TOOLS, GATEWAY_DESCRIBE_TOOLS, GATEWAY_CALL_TOOL / GATEWAY_RUN_CODE) - const strategyTools = strategyClient - ? await toolsFromMCP( - strategyClient, - toolOutputMap, - writer, + if (!virtualMcp) { + throw new Error("Agent not found"); + } + + // Mark thread as in_progress at the start of streaming + await ctx.storage.threads.update(mem.thread.id, { + status: "in_progress", + }); + sseHub.emit( + organization.id, + createDecopilotThreadStatusEvent(mem.thread.id, "in_progress"), + ); + + // Register run so it survives client disconnect; cancel uses run's AbortController + const run = runRegistry.startRun(mem.thread.id, organization.id, userId); + const abortSignal = run.abortController.signal; + + // Purge stale buffered chunks from any previous run on this thread + streamBuffer.purge(mem.thread.id); + + await saveMessagesToThread(requestMessage); + + // Always create a passthrough client (all real tools) + model client. + // If mode is smart_tool_selection or code_execution, also create the strategy + // client so we get the gateway meta-tools (SEARCH/DESCRIBE/CALL_TOOL/RUN_CODE). + const isGatewayMode = agent.mode !== "passthrough"; + const [modelClient, passthroughClient, strategyClient] = + await Promise.all([ + clientFromConnection(modelConnection, ctx, false), + createVirtualClientFrom(virtualMcp, ctx, "passthrough"), + isGatewayMode + ? createVirtualClientFrom(virtualMcp, ctx, agent.mode) + : Promise.resolve(null), + ]); + + closeClients = () => { + modelClient.close().catch(() => {}); + passthroughClient.close().catch(() => {}); + strategyClient?.close().catch(() => {}); + }; + + // Add streaming support since agents may use streaming models + const streamableModelClient = withStreamingSupport( + modelClient, + models.connectionId, + modelConnection, + ctx, + { superUser: false }, + ); + + // Extract model provider (can stay outside execute) + const modelProvider = await createModelProviderFromClient( + streamableModelClient, + models, + ); + + // MCP client cleanup on run abort (cancel from any pod), not request abort + abortSignal.addEventListener("abort", () => { + closeClients?.(); + failThread!(); + }); + + // Get server instructions if available (for virtual MCP agents) + const serverInstructions = passthroughClient.getInstructions(); + + // Merge platform instructions with request system messages + const systemPrompt = DECOPILOT_BASE_PROMPT(serverInstructions); + const allSystemMessages: ChatMessage[] = [ + systemPrompt, + ...systemMessages, + ]; + + const maxOutputTokens = + models.thinking.limits?.maxOutputTokens ?? DEFAULT_MAX_TOKENS; + + let streamFinished = false; + let stepCount = 0; + let pendingSave: Promise | null = null; + + const allMessages = await loadAndMergeMessages( + mem, + requestMessage, + allSystemMessages, + windowSize, + ); + + const toolOutputMap = new Map(); + // 4. Create stream with writer access for data parts + // Pass originalMessages so handleUIMessageStreamFinish (used by onFinish) + // can locate tool invocations from previous assistant messages during + // tool-approval continuation flows. + const uiStream = createUIMessageStream({ + originalMessages: allMessages, + execute: async ({ writer }) => { + // Create tools inside execute so they have access to writer + // Always get the full passthrough tools (all real tools from connections) + const passthroughTools = await toolsFromMCP( + passthroughClient, + toolOutputMap, + writer, + toolApprovalLevel, + ); + + // If using a gateway mode, also get the strategy meta-tools + // (GATEWAY_SEARCH_TOOLS, GATEWAY_DESCRIBE_TOOLS, GATEWAY_CALL_TOOL / GATEWAY_RUN_CODE) + const strategyTools = strategyClient + ? await toolsFromMCP( + strategyClient, + toolOutputMap, + writer, + toolApprovalLevel, + ) + : {}; + + const builtInTools = await getBuiltInTools( + writer, + { + modelProvider, + organization, + models, toolApprovalLevel, - ) - : {}; - - const builtInTools = await getBuiltInTools( - writer, - { - modelProvider, - organization, + toolOutputMap, + }, + ctx, + ); + + // Merge all tools: strategy meta-tools override passthrough tools with the same name, + // and built-in tools take final precedence. + const tools = { + ...passthroughTools, + ...strategyTools, + ...builtInTools, + }; + + // In gateway modes, only expose the strategy meta-tools + built-ins to the LLM. + // The passthrough tools are still registered (so the AI SDK won't throw if the + // model calls a discovered tool directly), but the LLM won't see their schemas. + const activeToolNames = strategyClient + ? ([ + ...Object.keys(strategyTools), + ...Object.keys(builtInTools), + ] as (keyof typeof tools)[]) + : undefined; + + // Process conversation with tools for validation + const { + systemMessages: processedSystemMessages, + messages: processedMessages, + originalMessages, + } = await processConversation(allMessages, { + windowSize, models, - toolApprovalLevel, - toolOutputMap, - }, - ctx, - ); + tools, + }); - // Merge all tools: strategy meta-tools override passthrough tools with the same name, - // and built-in tools take final precedence. - const tools = { - ...passthroughTools, - ...strategyTools, - ...builtInTools, - }; - - // In gateway modes, only expose the strategy meta-tools + built-ins to the LLM. - // The passthrough tools are still registered (so the AI SDK won't throw if the - // model calls a discovered tool directly), but the LLM won't see their schemas. - const activeToolNames = strategyClient - ? ([ - ...Object.keys(strategyTools), - ...Object.keys(builtInTools), - ] as (keyof typeof tools)[]) - : undefined; - - // Process conversation with tools for validation - const { - systemMessages: processedSystemMessages, - messages: processedMessages, - originalMessages, - } = await processConversation(allMessages, { - windowSize, - models, - tools, - }); + ensureModelCompatibility(models, originalMessages); - ensureModelCompatibility(models, originalMessages); + const shouldGenerateTitle = mem.thread.title === DEFAULT_THREAD_TITLE; + if (shouldGenerateTitle) { + genTitle({ + abortSignal, + model: modelProvider.fastModel ?? modelProvider.thinkingModel, + userMessage: JSON.stringify(processedMessages[0]?.content), + }).then(async (title) => { + if (!title) return; - const shouldGenerateTitle = mem.thread.title === DEFAULT_THREAD_TITLE; - if (shouldGenerateTitle) { - genTitle({ + await ctx.storage.threads + .update(mem.thread.id, { title }) + .catch((error) => { + console.error( + "[decopilot:stream] Error updating thread title", + error, + ); + }); + + if (!streamFinished) { + writer.write({ + type: "data-thread-title", + data: { title }, + transient: true, + }); + } + }); + } + + let reasoningStartAt: Date | null = null; + let lastProviderMetadata: Record | undefined; + + const result = streamText({ + model: modelProvider.thinkingModel, + system: processedSystemMessages, + messages: processedMessages, + tools, + activeTools: activeToolNames, + temperature, + maxOutputTokens, abortSignal, - model: modelProvider.fastModel ?? modelProvider.thinkingModel, - userMessage: JSON.stringify(processedMessages[0]?.content), - }).then(async (title) => { - if (!title) return; - - await ctx.storage.threads - .update(mem.thread.id, { title }) - .catch((error) => { - console.error( - "[decopilot:stream] Error updating thread title", - error, - ); - }); - - if (!streamFinished) { - writer.write({ - type: "data-thread-title", - data: { title }, - transient: true, - }); - } + stopWhen: stepCountIs(PARENT_STEP_LIMIT), + onError: async (error) => { + console.error("[decopilot:stream] Error", error); + throw error; + }, }); - } - - let reasoningStartAt: Date | null = null; - let lastProviderMetadata: Record | undefined; - const result = streamText({ - model: modelProvider.thinkingModel, - system: processedSystemMessages, - messages: processedMessages, - tools, - activeTools: activeToolNames, - temperature, - maxOutputTokens, - abortSignal, - stopWhen: stepCountIs(PARENT_STEP_LIMIT), - onError: async (error) => { - console.error("[decopilot:stream] Error", error); - throw error; - }, - }); - writer.merge( - result.toUIMessageStream({ + const uiMessageStream = result.toUIMessageStream({ originalMessages, generateMessageId, messageMetadata: ({ part }) => { @@ -425,108 +502,151 @@ app.post("/:org/decopilot/stream", async (c) => { return; }, - onFinish: async ({ responseMessage }) => { - streamFinished = true; - - const now = Date.now(); - const messagesToSave = [ - ...new Map( - [requestMessage, responseMessage] - .filter(Boolean) - .map((m) => [m.id, m]), - ).values(), - ].map((message, i) => ({ - ...message, - thread_id: mem.thread.id, - created_at: new Date(now + i).toISOString(), - updated_at: new Date(now + i).toISOString(), - })); - - if (messagesToSave.length === 0) return; - - await mem.save(messagesToSave).catch((error) => { - console.error( - "[decopilot:stream] Error saving messages", - error, - ); - }); - - // Determine and persist thread status - const finishReason = await result.finishReason; - const threadStatus = resolveThreadStatus( - finishReason, - responseMessage?.parts ?? [], - ); + }); - await ctx.storage.threads - .update(mem.thread.id, { status: threadStatus }) - .catch((error) => { - console.error( - "[decopilot:stream] Error updating thread status", - error, - ); - }); - }, - }), - ); - }, - onError: (error) => { - streamFinished = true; - console.error("[decopilot] stream error:", error); - - if (mem.thread.id) { - ctx.storage.threads - .update(mem.thread.id, { status: "failed" }) - .catch((statusErr) => { - console.error( - "[decopilot:stream] Error updating thread status on stream error", - statusErr, - ); - }); - } - - return error instanceof Error ? error.message : String(error); - }, - }); - - return createUIMessageStreamResponse({ - stream: uiStream, - consumeSseStream: consumeStream, - }); - } catch (err) { - // If we have a thread, mark it as failed - if (memory) { - const ctx = c.get("meshContext"); - await ctx.storage.threads - .update(memory.thread.id, { status: "failed" }) - .catch((statusErr: unknown) => { - console.error( - "[decopilot:stream] Failed to update thread status", - statusErr, + writer.merge( + streamBuffer.relay(uiMessageStream, mem.thread.id, abortSignal), + ); + }, + onFinish: async ({ responseMessage, finishReason }) => { + streamFinished = true; + closeClients?.(); + + if (pendingSave) await pendingSave; + await saveMessagesToThread(responseMessage); + + // Abort listener already called failThread(); skip status update + if (abortSignal.aborted) return; + + const threadStatus = resolveThreadStatus( + finishReason, + responseMessage?.parts as { + type: string; + state?: string; + text?: string; + }[], ); - }); - } - console.error("[decopilot:stream] Error", err); + completeThread(threadStatus); + }, + onStepFinish: ({ responseMessage }) => { + stepCount++; + sseHub.emit( + organization.id, + createDecopilotStepEvent(mem.thread.id, stepCount), + ); + if (stepCount % 5 === 0) { + pendingSave = saveMessagesToThread(responseMessage).finally(() => { + pendingSave = null; + }); + } + }, + onError: (error) => { + streamFinished = true; + closeClients?.(); + if (abortSignal.aborted) + return error instanceof Error ? error.message : String(error); + console.error("[decopilot] stream error:", error); + + if (mem.thread.id) { + failThread!(); + } + + return error instanceof Error ? error.message : String(error); + }, + }); + + return createUIMessageStreamResponse({ + stream: uiStream, + consumeSseStream: consumeStream, + }); + } catch (err) { + closeClients?.(); + if (failThread) { + failThread(); + } + + console.error("[decopilot:stream] Error", err); + + if (err instanceof HTTPException) { + return c.json({ error: err.message }, err.status); + } - if (err instanceof HTTPException) { - return c.json({ error: err.message }, err.status); + if (err instanceof Error && err.name === "AbortError") { + console.warn("[decopilot:stream] Aborted", { error: err.message }); + return c.json({ error: "Request aborted" }, 400); + } + + console.error("[decopilot:stream] Failed", { + error: err instanceof Error ? err.message : JSON.stringify(err), + stack: err instanceof Error ? err.stack : undefined, + }); + return c.json( + { error: err instanceof Error ? err.message : JSON.stringify(err) }, + 500, + ); } + }); - if (err instanceof Error && err.name === "AbortError") { - console.warn("[decopilot:stream] Aborted", { error: err.message }); - return c.json({ error: "Request aborted" }, 400); + // ============================================================================ + // Cancel Endpoint — cancel ongoing run (local or via NATS to owning pod) + // ============================================================================ + + app.post("/:org/decopilot/cancel/:threadId", async (c) => { + const { threadId } = await validateThreadOwnership(c); + + if (runRegistry.cancelLocal(threadId)) { + return c.json({ cancelled: true }); } - console.error("[decopilot:stream] Failed", { - error: err instanceof Error ? err.message : JSON.stringify(err), - stack: err instanceof Error ? err.stack : undefined, - }); - return c.json( - { error: err instanceof Error ? err.message : JSON.stringify(err) }, - 500, - ); - } -}); + // Not on this pod — broadcast to all pods + cancelBroadcast.broadcast(threadId); + return c.json({ cancelled: true, async: true }, 202); + }); -export default app; + // ============================================================================ + // Attach Endpoint — replay JetStream-buffered stream for late-joining clients + // ============================================================================ + + app.get("/:org/decopilot/attach/:threadId", async (c) => { + try { + const { threadId } = await validateThreadOwnership(c); + + const run = runRegistry.getRun(threadId); + if (!run || run.status !== "running") { + return c.body(null, 204); + } + + const replayChunkStream = await streamBuffer.createReplayStream(threadId); + if (!replayChunkStream) { + return c.body(null, 204); + } + + const replayStream = createUIMessageStream({ + execute: async ({ writer }) => { + const reader = replayChunkStream.getReader(); + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + writer.write(value); + } + } finally { + reader.releaseLock(); + } + }, + }); + + return createUIMessageStreamResponse({ + stream: replayStream, + consumeSseStream: consumeStream, + }); + } catch (err) { + if (err instanceof HTTPException) throw err; + console.error("[decopilot:attach] Error", err); + return c.body(null, 500); + } + }); + + return app; +} diff --git a/apps/mesh/src/api/routes/decopilot/run-registry.test.ts b/apps/mesh/src/api/routes/decopilot/run-registry.test.ts new file mode 100644 index 0000000000..560edb8e15 --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/run-registry.test.ts @@ -0,0 +1,172 @@ +import { describe, it, expect, mock } from "bun:test"; +import type { ThreadStoragePort } from "@/storage/ports"; +import { RunRegistry } from "./run-registry"; + +function mockStorage(): ThreadStoragePort { + return { + update: mock(() => Promise.resolve({} as never)), + create: mock(() => Promise.resolve({} as never)), + get: mock(() => Promise.resolve(null)), + delete: mock(() => Promise.resolve()), + list: mock(() => Promise.resolve({ threads: [], total: 0 })), + saveMessages: mock(() => Promise.resolve()), + listMessages: mock(() => Promise.resolve({ messages: [], total: 0 })), + }; +} + +describe("RunRegistry", () => { + function createRegistry() { + return new RunRegistry(); + } + + describe("startRun", () => { + it("creates a new run with correct fields and running status", () => { + const registry = createRegistry(); + const run = registry.startRun("t1", "org1", "u1"); + + expect(run.threadId).toBe("t1"); + expect(run.orgId).toBe("org1"); + expect(run.userId).toBe("u1"); + expect(run.status).toBe("running"); + expect(run.abortController).toBeInstanceOf(AbortController); + expect(run.abortController.signal.aborted).toBe(false); + expect(run.startedAt).toBeInstanceOf(Date); + }); + + it("aborts and replaces an existing running entry for the same threadId", () => { + const registry = createRegistry(); + const first = registry.startRun("t1", "org1", "u1"); + const firstAbort = first.abortController; + + const second = registry.startRun("t1", "org2", "u2"); + + expect(firstAbort.signal.aborted).toBe(true); + expect(first.status).toBe("failed"); + expect(second.threadId).toBe("t1"); + expect(second.orgId).toBe("org2"); + expect(second.status).toBe("running"); + expect(registry.getRun("t1")).toBe(second); + }); + + it("replaces a non-running entry without aborting", () => { + const registry = createRegistry(); + const first = registry.startRun("t1", "org1", "u1"); + registry.completeRun("t1", "completed"); + const firstAbort = first.abortController; + + registry.startRun("t1", "org1", "u1"); + + expect(firstAbort.signal.aborted).toBe(false); + }); + }); + + describe("getRun", () => { + it("returns the run for an existing threadId", () => { + const registry = createRegistry(); + const run = registry.startRun("t1", "org1", "u1"); + expect(registry.getRun("t1")).toBe(run); + }); + + it("returns undefined for a non-existent threadId", () => { + const registry = createRegistry(); + expect(registry.getRun("nope")).toBeUndefined(); + }); + }); + + describe("cancelLocal", () => { + it("returns true and aborts a running entry", () => { + const registry = createRegistry(); + const run = registry.startRun("t1", "org1", "u1"); + const result = registry.cancelLocal("t1"); + + expect(result).toBe(true); + expect(run.status).toBe("failed"); + expect(run.abortController.signal.aborted).toBe(true); + }); + + it("returns false for a non-existent threadId", () => { + const registry = createRegistry(); + expect(registry.cancelLocal("nope")).toBe(false); + }); + + it("returns false for a non-running entry", () => { + const registry = createRegistry(); + registry.startRun("t1", "org1", "u1"); + registry.completeRun("t1", "completed"); + + expect(registry.cancelLocal("t1")).toBe(false); + }); + }); + + describe("completeRun", () => { + it("sets status and deletes from the map", () => { + const registry = createRegistry(); + const run = registry.startRun("t1", "org1", "u1"); + registry.completeRun("t1", "completed"); + + expect(run.status).toBe("completed"); + expect(registry.getRun("t1")).toBeUndefined(); + }); + + it("is a no-op for a non-existent threadId", () => { + const registry = createRegistry(); + registry.completeRun("no-such-thread", "failed"); + expect(registry.getRun("no-such-thread")).toBeUndefined(); + }); + }); + + describe("finishRun", () => { + it("calls completeRun and invokes onPurge callback", () => { + const registry = createRegistry(); + registry.startRun("t1", "org1", "u1"); + const purged: string[] = []; + + registry.finishRun("t1", "completed", (id) => purged.push(id)); + + expect(registry.getRun("t1")).toBeUndefined(); + expect(purged).toEqual(["t1"]); + }); + + it("works without onPurge callback", () => { + const registry = createRegistry(); + const run = registry.startRun("t1", "org1", "u1"); + registry.finishRun("t1", "failed"); + + expect(run.status).toBe("failed"); + expect(registry.getRun("t1")).toBeUndefined(); + }); + + it("is a no-op for non-existent threadId (no throw)", () => { + const registry = createRegistry(); + const purged: string[] = []; + registry.finishRun("no-such-thread", "failed", (id) => purged.push(id)); + expect(purged).toEqual(["no-such-thread"]); + }); + }); + + describe("stopAll", () => { + it("aborts all running entries and clears the map", () => { + const registry = createRegistry(); + const storage = mockStorage(); + const run1 = registry.startRun("t1", "org1", "u1"); + const run2 = registry.startRun("t2", "org1", "u2"); + + const completedRun = registry.startRun("t3", "org1", "u3"); + completedRun.status = "completed" as const; + + registry.stopAll(storage); + + expect(run1.abortController.signal.aborted).toBe(true); + expect(run2.abortController.signal.aborted).toBe(true); + expect(completedRun.abortController.signal.aborted).toBe(false); + + expect(storage.update).toHaveBeenCalledTimes(2); + expect(storage.update).toHaveBeenCalledWith("t1", { status: "failed" }); + expect(storage.update).toHaveBeenCalledWith("t2", { status: "failed" }); + + expect(registry.getRun("t1")).toBeUndefined(); + expect(registry.getRun("t2")).toBeUndefined(); + expect(registry.getRun("t3")).toBeUndefined(); + }); + }); +}); diff --git a/apps/mesh/src/api/routes/decopilot/run-registry.ts b/apps/mesh/src/api/routes/decopilot/run-registry.ts new file mode 100644 index 0000000000..c839e7f545 --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/run-registry.ts @@ -0,0 +1,120 @@ +/** + * RunRegistry — in-memory registry of active Decopilot runs + * + * Tracks running streamText loops by threadId so they survive client disconnect. + * Cancel is propagated via NATS to the pod that owns the run. + */ + +import type { ThreadStoragePort } from "@/storage/ports"; + +export interface ActiveRun { + threadId: string; + orgId: string; + userId: string; + abortController: AbortController; + status: "running" | "completed" | "failed"; + startedAt: Date; +} + +const REAP_INTERVAL_MS = 5 * 60 * 1000; // 5 minutes +const MAX_RUN_AGE_MS = 30 * 60 * 1000; // 30 minutes + +export class RunRegistry { + private readonly runs = new Map(); + private reaperTimer: ReturnType | null = null; + + constructor() { + this.reaperTimer = setInterval( + () => this.reapStaleRuns(), + REAP_INTERVAL_MS, + ); + } + + private reapStaleRuns(): void { + const now = Date.now(); + for (const [threadId, run] of this.runs) { + if ( + run.status === "running" && + now - run.startedAt.getTime() > MAX_RUN_AGE_MS + ) { + console.warn( + `[RunRegistry] Reaping stale run for thread ${threadId} (age: ${Math.round((now - run.startedAt.getTime()) / 60_000)}min)`, + ); + run.status = "failed"; + run.abortController.abort(); + this.runs.delete(threadId); + } + } + } + + startRun(threadId: string, orgId: string, userId: string): ActiveRun { + const existing = this.runs.get(threadId); + if (existing) { + if (existing.status === "running") { + existing.abortController.abort(); + } + existing.status = "failed"; + this.runs.delete(threadId); + } + const run: ActiveRun = { + threadId, + orgId, + userId, + abortController: new AbortController(), + status: "running", + startedAt: new Date(), + }; + this.runs.set(threadId, run); + return run; + } + + getRun(threadId: string): ActiveRun | undefined { + return this.runs.get(threadId); + } + + cancelLocal(threadId: string): boolean { + const run = this.runs.get(threadId); + if (!run || run.status !== "running") return false; + run.status = "failed"; + run.abortController.abort(); + return true; + } + + completeRun(threadId: string, status: "completed" | "failed"): void { + const run = this.runs.get(threadId); + if (run) { + run.status = status; + this.runs.delete(threadId); + } + } + + /** + * Finish a run: update status, remove from registry, and purge stream buffer. + * Unifies completeRun + purge into a single call to avoid split call sites. + */ + finishRun( + threadId: string, + status: "completed" | "failed", + onPurge?: (threadId: string) => void, + ): void { + this.completeRun(threadId, status); + onPurge?.(threadId); + } + + stopAll(storage: ThreadStoragePort): void { + for (const [threadId, run] of this.runs) { + if (run.status === "running") { + run.abortController.abort(); + storage.update(threadId, { status: "failed" }).catch(() => {}); + } + } + this.runs.clear(); + } + + dispose(): void { + if (this.reaperTimer) { + clearInterval(this.reaperTimer); + this.reaperTimer = null; + } + } +} diff --git a/apps/mesh/src/api/routes/decopilot/status.test.ts b/apps/mesh/src/api/routes/decopilot/status.test.ts index bd5c12e450..b8ac879792 100644 --- a/apps/mesh/src/api/routes/decopilot/status.test.ts +++ b/apps/mesh/src/api/routes/decopilot/status.test.ts @@ -6,75 +6,11 @@ describe("resolveThreadStatus", () => { expect(resolveThreadStatus("stop", [])).toBe("completed"); }); - test("stop with last text part containing ? -> requires_action", () => { + test("stop always returns completed regardless of text content", () => { const parts = [ { type: "text", text: "Here is the answer." }, { type: "text", text: "Does that help?" }, ]; - expect(resolveThreadStatus("stop", parts)).toBe("requires_action"); - }); - - test("stop with last text part not containing ? -> completed", () => { - const parts = [{ type: "text", text: "Here is the answer." }]; - expect(resolveThreadStatus("stop", parts)).toBe("completed"); - }); - - test("stop with last text part (after non-text) containing ? -> requires_action", () => { - const parts = [ - { type: "text", text: "Done." }, - { type: "tool-invocation", toolName: "x", state: "result" }, - { type: "text", text: "Want more?" }, - ]; - expect(resolveThreadStatus("stop", parts)).toBe("requires_action"); - }); - - test("stop with URL containing query string in last text part -> completed", () => { - const parts = [ - { - type: "text", - text: "Check this link: https://example.com/api?foo=bar&baz=qux", - }, - ]; - expect(resolveThreadStatus("stop", parts)).toBe("completed"); - }); - - test("stop with inline code containing ? (ternary) in last text part -> completed", () => { - const parts = [ - { - type: "text", - text: "Use a ternary: `x ? y : z` for that.", - }, - ]; - expect(resolveThreadStatus("stop", parts)).toBe("completed"); - }); - - test("stop with fenced code block containing ? in last text part -> completed", () => { - const parts = [ - { - type: "text", - text: "Here's the code:\n\n```js\nconst x = a ? b : c;\n```\n\nDone.", - }, - ]; - expect(resolveThreadStatus("stop", parts)).toBe("completed"); - }); - - test("stop with URL and real question in last text part -> requires_action", () => { - const parts = [ - { - type: "text", - text: "See https://example.com?ref=1 for details. Does that help?", - }, - ]; - expect(resolveThreadStatus("stop", parts)).toBe("requires_action"); - }); - - test("stop with markdown image containing pre-signed S3 URL -> completed", () => { - const parts = [ - { - type: "text", - text: "Perfect! I've generated an image of a capybara having ice cream for you! \n\n![Capybara enjoying ice cream](https://deco-chat-shared-deco-team.c95fc4cec7fc52453228d9db170c372c.r2.cloudflarestorage.com//images/2026-02-18T16-25-14-100Z.png?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=12fd512fec8b8158e9e414db6675a3d9%2F20260218%2Fauto%2Fs3%2Faws4_request&X-Amz-Date=20260218T162516Z&X-Amz-Expires=3600&X-Amz-Signature=d7372684ded0dd344372e83b7c1953192cb498a697ae7dd713b24cb4c6f16c20&X-Amz-SignedHeaders=host&x-amz-checksum-mode=ENABLED&x-id=GetObject)\n\nHere's your adorable capybara enjoying some ice cream! 🍦", - }, - ]; expect(resolveThreadStatus("stop", parts)).toBe("completed"); }); diff --git a/apps/mesh/src/api/routes/decopilot/status.ts b/apps/mesh/src/api/routes/decopilot/status.ts index c44c2a850d..1a2639b521 100644 --- a/apps/mesh/src/api/routes/decopilot/status.ts +++ b/apps/mesh/src/api/routes/decopilot/status.ts @@ -14,22 +14,6 @@ type ResponsePart = { state?: string; }; -/** - * Returns true if the text contains a direct question to the user (sentence-ending ?). - * Strips URLs, code blocks, and inline code to avoid false positives from query strings, - * ternary operators, regex literals, etc. - */ -function hasDirectQuestion(text: string): boolean { - const sanitized = text - .replace(/```[\s\S]*?```/g, "") - .replace(/`[^`]*`/g, "") - .replace(/https?:\/\/\S+/g, "") - .replace(/www\.\S+/g, ""); - - const lastParagraph = sanitized.split(/\n\s*\n/).at(-1) ?? sanitized; - return /\?(\s|[)"'\]},]|$)/m.test(lastParagraph); -} - /** * Resolves the thread status from the AI SDK stream result. * @@ -42,11 +26,6 @@ export function resolveThreadStatus( responseParts: ResponsePart[] = [], ): ThreadStatus { if (finishReason === "stop") { - // Question in last text part -> waiting for user answer - const lastTextPart = responseParts.findLast((p) => p.type === "text"); - if (lastTextPart?.text && hasDirectQuestion(lastTextPart.text)) { - return "requires_action"; - } return "completed"; } diff --git a/apps/mesh/src/api/routes/decopilot/stream-buffer.ts b/apps/mesh/src/api/routes/decopilot/stream-buffer.ts new file mode 100644 index 0000000000..6bb2a31bdc --- /dev/null +++ b/apps/mesh/src/api/routes/decopilot/stream-buffer.ts @@ -0,0 +1,63 @@ +/** + * Stream Buffer Interface + * + * Abstraction for buffering UIMessageStream chunks so late-joining + * clients can replay them from any pod. + * + * In single-process mode without NATS/JetStream, the buffer is a no-op + * (late-join is disabled, relay is a passthrough). + * + * Mirrors the SSEBroadcastStrategy / CancelBroadcast pattern. + */ + +/** + * StreamBuffer allows buffering and replaying UIMessageStream chunks + * for late-joining clients (the /attach endpoint). + */ +export interface StreamBuffer { + /** Initialize the buffer (e.g., ensure JetStream stream exists). */ + init(): Promise; + + /** + * Wrap a ReadableStream so every chunk is also buffered. + * Returns a new stream that passes through all chunks unchanged. + * If the buffer is unavailable, returns the original stream as-is. + */ + relay( + stream: ReadableStream, + threadId: string, + abortSignal?: AbortSignal, + ): ReadableStream; + + /** + * Create a replay stream for a late-joining client. + * Returns null if buffering is not available or the thread has no data. + */ + createReplayStream(threadId: string): Promise; + + /** Purge buffered data for a thread (best-effort, fire-and-forget). */ + purge(threadId: string): void; + + /** Release resources (clear references, called on shutdown). */ + teardown(): void; +} + +/** + * No-op stream buffer — late-join disabled, relay is passthrough. + * Used when NATS/JetStream is not configured. + */ +export class NoOpStreamBuffer implements StreamBuffer { + async init(): Promise {} + + relay(stream: ReadableStream): ReadableStream { + return stream; + } + + async createReplayStream(): Promise { + return null; + } + + purge(): void {} + + teardown(): void {} +} diff --git a/apps/mesh/src/event-bus/index.ts b/apps/mesh/src/event-bus/index.ts index d4e1b9dafb..7ab5f142e2 100644 --- a/apps/mesh/src/event-bus/index.ts +++ b/apps/mesh/src/event-bus/index.ts @@ -22,6 +22,7 @@ * ``` */ +import type { NatsConnectionProvider } from "../nats/connection"; import type { MeshDatabase } from "../database"; import { createEventBusStorage } from "../storage/event-bus"; import { EventBus as EventBusImpl } from "./event-bus"; @@ -111,11 +112,13 @@ function resolveNotifyStrategy(database: MeshDatabase): NotifyStrategyName { * * @param database - MeshDatabase instance (discriminated union) * @param config - Optional event bus configuration + * @param natsProvider - Optional shared NATS connection provider (when using NATS strategies) * @returns EventBus instance */ export function createEventBus( database: MeshDatabase, config?: EventBusConfig, + natsProvider?: NatsConnectionProvider | null, ): EventBus { const storage = createEventBusStorage(database.db); const pollIntervalMs = @@ -140,10 +143,19 @@ export function createEventBus( return "unknown"; } })(); + if (!natsProvider) { + console.warn( + `[EventBus] NATS unavailable (${natsHost}), falling back to polling`, + ); + notifyStrategy = polling; + break; + } console.log(`[EventBus] Using NATS notify strategy (${natsHost})`); notifyStrategy = compose( polling, - new NatsNotifyStrategy({ servers: natsUrl }), + new NatsNotifyStrategy({ + getConnection: () => natsProvider!.getConnection(), + }), ); break; } @@ -170,15 +182,18 @@ export function createEventBus( // Start SSE hub with the appropriate broadcast strategy. // NATS available → cross-pod fan-out; otherwise → local only. - const sseBroadcast = natsUrl - ? new NatsSSEBroadcast({ servers: natsUrl }) - : new LocalSSEBroadcast(); + const sseBroadcast = + natsUrl && natsProvider + ? new NatsSSEBroadcast({ + getConnection: () => natsProvider!.getConnection(), + }) + : new LocalSSEBroadcast(); sseHub.start(sseBroadcast).catch((err) => { console.error("[SSEHub] Failed to start broadcast strategy:", err); }); - if (natsUrl) { + if (natsUrl && natsProvider) { console.log("[SSEHub] Using NATS SSE broadcast (cross-pod)"); } else { console.log("[SSEHub] Using local SSE broadcast (single-pod)"); diff --git a/apps/mesh/src/event-bus/nats-notify.ts b/apps/mesh/src/event-bus/nats-notify.ts index 51c74f4cbf..1f5e527df9 100644 --- a/apps/mesh/src/event-bus/nats-notify.ts +++ b/apps/mesh/src/event-bus/nats-notify.ts @@ -7,35 +7,32 @@ * Architecture: * - `notify()`: Publishes to a NATS subject * - `start()`: Subscribes to the subject and calls onNotify() on each message - * - Reconnection is handled transparently by the nats.js client + * - Connection is provided by NatsConnectionProvider (does not own/drain) */ -import { connect, type NatsConnection, type Subscription } from "nats"; +import type { NatsConnection, Subscription } from "nats"; import type { NotifyStrategy } from "./notify-strategy"; const SUBJECT = "mesh.events.notify"; export interface NatsNotifyStrategyOptions { - /** NATS server URL(s), e.g. "nats://localhost:4222" */ - servers: string | string[]; + getConnection: () => NatsConnection | null; } export class NatsNotifyStrategy implements NotifyStrategy { - private nc: NatsConnection | null = null; private sub: Subscription | null = null; private onNotify: (() => void) | null = null; + private readonly encoder = new TextEncoder(); constructor(private readonly options: NatsNotifyStrategyOptions) {} async start(onNotify: () => void): Promise { - if (this.nc) return; // Already started + const nc = this.options.getConnection(); + if (!nc || this.sub) return; this.onNotify = onNotify; - this.nc = await connect({ servers: this.options.servers }); + this.sub = nc.subscribe(SUBJECT); - this.sub = this.nc.subscribe(SUBJECT); - - // Process messages in background — each message wakes the worker (async () => { for await (const _msg of this.sub!) { this.onNotify?.(); @@ -51,22 +48,16 @@ export class NatsNotifyStrategy implements NotifyStrategy { this.sub?.unsubscribe(); this.sub = null; this.onNotify = null; - - if (this.nc) { - await this.nc.drain(); - this.nc = null; - } - console.log("[NatsNotify] Stopped"); } async notify(eventId: string): Promise { - if (!this.nc) return; + const nc = this.options.getConnection(); + if (!nc) return; try { - this.nc.publish(SUBJECT, new TextEncoder().encode(eventId)); + nc.publish(SUBJECT, this.encoder.encode(eventId)); } catch (err) { - // Non-critical — polling safety net will still pick it up console.warn("[NatsNotify] Publish failed (non-critical):", err); } } diff --git a/apps/mesh/src/event-bus/nats-sse-broadcast.ts b/apps/mesh/src/event-bus/nats-sse-broadcast.ts index a4a8bb348d..b93294682d 100644 --- a/apps/mesh/src/event-bus/nats-sse-broadcast.ts +++ b/apps/mesh/src/event-bus/nats-sse-broadcast.ts @@ -6,9 +6,10 @@ * when it receives a message, so SSE clients on every pod get the event. * * Uses a per-instance origin ID to avoid double-emitting on the publisher pod. + * Connection is provided by NatsConnectionProvider (does not own/drain). */ -import { connect, type NatsConnection, type Subscription } from "nats"; +import type { NatsConnection, Subscription } from "nats"; import type { SSEEvent } from "./sse-hub"; import type { LocalEmitFn, @@ -24,35 +25,23 @@ interface NatsSSEMessage { } export interface NatsSSEBroadcastOptions { - servers: string | string[]; + getConnection: () => NatsConnection | null; } export class NatsSSEBroadcast implements SSEBroadcastStrategy { - private nc: NatsConnection | null = null; private sub: Subscription | null = null; private localEmit: LocalEmitFn | null = null; - private startPromise: Promise | null = null; private readonly originId = crypto.randomUUID(); private readonly encoder = new TextEncoder(); constructor(private readonly options: NatsSSEBroadcastOptions) {} async start(localEmit: LocalEmitFn): Promise { - if (this.nc) return; - if (this.startPromise) return this.startPromise; - - this.startPromise = this._doStart(localEmit); - try { - await this.startPromise; - } finally { - this.startPromise = null; - } - } - - private async _doStart(localEmit: LocalEmitFn): Promise { this.localEmit = localEmit; - this.nc = await connect({ servers: this.options.servers }); - this.sub = this.nc.subscribe(SUBJECT); + + const nc = this.options.getConnection(); + if (!nc || this.sub) return; + this.sub = nc.subscribe(SUBJECT); const decoder = new TextDecoder(); @@ -82,10 +71,10 @@ export class NatsSSEBroadcast implements SSEBroadcastStrategy { } broadcast(organizationId: string, event: SSEEvent): void { - // Always emit locally first (fast path for SSE clients on this pod) this.localEmit?.(organizationId, event); - if (!this.nc) return; + const nc = this.options.getConnection(); + if (!nc) return; const payload: NatsSSEMessage = { originId: this.originId, @@ -94,7 +83,7 @@ export class NatsSSEBroadcast implements SSEBroadcastStrategy { }; try { - this.nc.publish(SUBJECT, this.encoder.encode(JSON.stringify(payload))); + nc.publish(SUBJECT, this.encoder.encode(JSON.stringify(payload))); } catch (err) { console.warn("[NatsSSEBroadcast] Publish failed (non-critical):", err); } @@ -104,12 +93,6 @@ export class NatsSSEBroadcast implements SSEBroadcastStrategy { this.sub?.unsubscribe(); this.sub = null; this.localEmit = null; - - if (this.nc) { - await this.nc.drain(); - this.nc = null; - } - console.log("[NatsSSEBroadcast] Stopped"); } } diff --git a/apps/mesh/src/nats/connection.test.ts b/apps/mesh/src/nats/connection.test.ts new file mode 100644 index 0000000000..f65e3be4b6 --- /dev/null +++ b/apps/mesh/src/nats/connection.test.ts @@ -0,0 +1,30 @@ +import { describe, it, expect } from "bun:test"; +import { createNatsConnectionProvider } from "./connection"; + +// Mock the nats `connect` function at the module level is impractical +// without the real NATS server, so we test the structural guarantees: +// idempotent init, getConnection/getJetStream null before init, drain clears state. + +describe("createNatsConnectionProvider (unit)", () => { + it("getConnection returns null before init", () => { + const provider = createNatsConnectionProvider(); + expect(provider.getConnection()).toBeNull(); + }); + + it("getJetStream returns null before init", () => { + const provider = createNatsConnectionProvider(); + expect(provider.getJetStream()).toBeNull(); + }); + + it("drain is safe to call before init (no throw)", async () => { + const provider = createNatsConnectionProvider(); + await expect(provider.drain()).resolves.toBeUndefined(); + }); + + it("drain clears state so getConnection returns null after drain", async () => { + const provider = createNatsConnectionProvider(); + await provider.drain(); + expect(provider.getConnection()).toBeNull(); + expect(provider.getJetStream()).toBeNull(); + }); +}); diff --git a/apps/mesh/src/nats/connection.ts b/apps/mesh/src/nats/connection.ts new file mode 100644 index 0000000000..4f0a8b47ab --- /dev/null +++ b/apps/mesh/src/nats/connection.ts @@ -0,0 +1,66 @@ +/** + * Shared NATS Connection Provider + * + * Manages a single NATS connection shared by all NATS implementations: + * - NatsCancelBroadcast (decopilot cancel) + * - NatsStreamBuffer (decopilot JetStream relay) + * - NatsNotifyStrategy (event bus wake-up) + * - NatsSSEBroadcast (cross-pod SSE fan-out) + * + * Benefits: + * - Single connection to NATS server (recommended best practice) + * - One place for reconnect logic and error handling + * - Clear shutdown ordering (drain shared connection last) + */ + +import { connect, type JetStreamClient, type NatsConnection } from "nats"; + +export interface NatsConnectionProvider { + /** Connect to NATS eagerly. Fails fast if unreachable. */ + init(url: string | string[]): Promise; + /** Returns the shared connection, or null if not initialized. */ + getConnection(): NatsConnection | null; + /** Returns a JetStream client from the shared connection, or null. */ + getJetStream(): JetStreamClient | null; + /** Drain the connection. Call after all consumers have stopped. */ + drain(): Promise; +} + +/** + * Create a NatsConnectionProvider instance. + * Typically one per process. + */ +export function createNatsConnectionProvider(): NatsConnectionProvider { + let nc: NatsConnection | null = null; + let js: JetStreamClient | null = null; + + return { + async init(url: string | string[]): Promise { + if (nc) return; + nc = await connect({ servers: url }); + console.log("[NATS] Connected"); + }, + + getConnection(): NatsConnection | null { + return nc; + }, + + getJetStream(): JetStreamClient | null { + if (!nc) return null; + if (!js) { + js = nc.jetstream(); + } + return js; + }, + + async drain(): Promise { + js = null; + if (nc) { + const conn = nc; + nc = null; + await conn.drain().catch(() => {}); + console.log("[NATS] Connection drained"); + } + }, + }; +} diff --git a/apps/mesh/src/storage/types.ts b/apps/mesh/src/storage/types.ts index 7f27f59341..0528787707 100644 --- a/apps/mesh/src/storage/types.ts +++ b/apps/mesh/src/storage/types.ts @@ -14,6 +14,7 @@ import type { ColumnType } from "kysely"; import type { OAuthConfig, ToolDefinition } from "../tools/connection/schema"; import type { ChatMessage } from "../api/routes/decopilot/types"; +import { ThreadStatus } from "@decocms/mesh-sdk"; // ============================================================================ // Type Utilities @@ -691,14 +692,11 @@ export interface ConnectionAggregationTable { * Threads are scopes users in organizations and store messages with Agents. */ -/** Stored thread statuses (persisted in DB). */ -export const THREAD_STATUSES = [ - "in_progress", - "requires_action", - "failed", - "completed", -] as const; -export type ThreadStatus = (typeof THREAD_STATUSES)[number]; +/** Stored thread statuses (persisted in DB). Canonical source: @decocms/mesh-sdk */ +export { + THREAD_STATUSES, + type ThreadStatus, +} from "@decocms/mesh-sdk"; export interface ThreadTable { id: string; diff --git a/apps/mesh/src/web/components/chat/chat-state.ts b/apps/mesh/src/web/components/chat/chat-state.ts new file mode 100644 index 0000000000..5508c831b5 --- /dev/null +++ b/apps/mesh/src/web/components/chat/chat-state.ts @@ -0,0 +1,63 @@ +/** + * Chat state reducer and types + * + * Extracted from context.tsx so tests can import the reducer without + * pulling in the entire UI dependency graph. + */ + +import type { ParentThread } from "./types"; + +/** + * Chat state — shared across the Decopilot chat provider. + * + * NOTE: tiptapDoc is intentionally NOT here — it lives as local state in + * ChatInput to avoid re-rendering the entire context tree on every keystroke. + */ +export interface ChatState { + /** Active parent thread if branching is in progress */ + parentThread: ParentThread | null; + /** Finish reason from the last chat completion */ + finishReason: string | null; +} + +/** + * Actions for the chat state reducer + */ +export type ChatStateAction = + | { type: "START_BRANCH"; payload: ParentThread } + | { type: "CLEAR_BRANCH" } + | { type: "SET_FINISH_REASON"; payload: string | null } + | { type: "CLEAR_FINISH_REASON" } + | { type: "RESET" }; + +/** + * Initial chat state + */ +export const initialChatState: ChatState = { + parentThread: null, + finishReason: null, +}; + +/** + * Reducer for chat state + */ +export function chatStateReducer( + state: ChatState, + action: ChatStateAction, +): ChatState { + switch (action.type) { + case "START_BRANCH": + return { ...state, parentThread: action.payload }; + case "CLEAR_BRANCH": + return { ...state, parentThread: null }; + case "SET_FINISH_REASON": + if (state.finishReason === action.payload) return state; + return { ...state, finishReason: action.payload }; + case "CLEAR_FINISH_REASON": + return { ...state, finishReason: null }; + case "RESET": + return initialChatState; + default: + return state; + } +} diff --git a/apps/mesh/src/web/components/chat/context.test.tsx b/apps/mesh/src/web/components/chat/context.test.tsx index 5dd61737d2..59dcdd7733 100644 --- a/apps/mesh/src/web/components/chat/context.test.tsx +++ b/apps/mesh/src/web/components/chat/context.test.tsx @@ -2,76 +2,28 @@ * Tests for ChatState Reducer * * Tests the reducer logic for the chat state management. + * NOTE: tiptapDoc was moved out of the reducer into ChatInput local state. */ import { describe, expect, test } from "bun:test"; import type { ParentThread } from "./types.ts"; -import type { ChatState, ChatStateAction } from "./context"; - -// Import the reducer directly for testing -// Since it's not exported, we'll test through the exported types -// In a real scenario, you might want to export the reducer for testing +import { + chatStateReducer, + type ChatState, + type ChatStateAction, +} from "./chat-state"; describe("ChatState Reducer Logic", () => { const initialState: ChatState = { - tiptapDoc: undefined, parentThread: null, finishReason: null, }; - // Helper to simulate reducer behavior - function applyAction(state: ChatState, action: ChatStateAction): ChatState { - switch (action.type) { - case "SET_TIPTAP_DOC": - return { ...state, tiptapDoc: action.payload }; - case "CLEAR_TIPTAP_DOC": - return { ...state, tiptapDoc: undefined }; - case "START_BRANCH": - return { ...state, parentThread: action.payload }; - case "CLEAR_BRANCH": - return { ...state, parentThread: null }; - case "SET_FINISH_REASON": - return { ...state, finishReason: action.payload }; - case "CLEAR_FINISH_REASON": - return { ...state, finishReason: null }; - case "RESET": - return { - tiptapDoc: undefined, - parentThread: null, - finishReason: null, - }; - default: - return state; - } - } - test("should initialize with empty state", () => { - expect(initialState.tiptapDoc).toBeUndefined(); expect(initialState.parentThread).toBeNull(); expect(initialState.finishReason).toBeNull(); }); - test("should update tiptap doc with SET_TIPTAP_DOC action", () => { - const doc = { - type: "doc" as const, - content: [ - { - type: "paragraph", - content: [{ type: "text", text: "Hello, world!" }], - }, - ], - }; - const action: ChatStateAction = { - type: "SET_TIPTAP_DOC", - payload: doc, - }; - - const newState = applyAction(initialState, action); - - expect(newState.tiptapDoc).toEqual(doc); - expect(newState.parentThread).toBeNull(); - }); - test("should start branch with START_BRANCH action", () => { const parentThread: ParentThread = { thread_id: "thread-123", @@ -83,23 +35,13 @@ describe("ChatState Reducer Logic", () => { payload: parentThread, }; - const newState = applyAction(initialState, action); + const newState = chatStateReducer(initialState, action); expect(newState.parentThread).toEqual(parentThread); - expect(newState.tiptapDoc).toBeUndefined(); }); test("should clear branch context with CLEAR_BRANCH action", () => { const stateWithBranch: ChatState = { - tiptapDoc: { - type: "doc", - content: [ - { - type: "paragraph", - content: [{ type: "text", text: "Some input" }], - }, - ], - }, parentThread: { thread_id: "thread-123", messageId: "msg-456", @@ -109,10 +51,9 @@ describe("ChatState Reducer Logic", () => { const action: ChatStateAction = { type: "CLEAR_BRANCH" }; - const newState = applyAction(stateWithBranch, action); + const newState = chatStateReducer(stateWithBranch, action); expect(newState.parentThread).toBeNull(); - expect(newState.tiptapDoc).toEqual(stateWithBranch.tiptapDoc); // Tiptap doc should remain }); test("should set finish reason with SET_FINISH_REASON action", () => { @@ -121,37 +62,43 @@ describe("ChatState Reducer Logic", () => { payload: "stop", }; - const newState = applyAction(initialState, action); + const newState = chatStateReducer(initialState, action); expect(newState.finishReason).toBe("stop"); expect(newState.parentThread).toBeNull(); }); + test("SET_FINISH_REASON returns same reference when payload is unchanged", () => { + const stateWithReason: ChatState = { + parentThread: null, + finishReason: "stop", + }; + + const action: ChatStateAction = { + type: "SET_FINISH_REASON", + payload: "stop", + }; + + const newState = chatStateReducer(stateWithReason, action); + + expect(newState).toBe(stateWithReason); + }); + test("should clear finish reason with CLEAR_FINISH_REASON action", () => { const stateWithFinishReason: ChatState = { - tiptapDoc: undefined, parentThread: null, finishReason: "stop", }; const action: ChatStateAction = { type: "CLEAR_FINISH_REASON" }; - const newState = applyAction(stateWithFinishReason, action); + const newState = chatStateReducer(stateWithFinishReason, action); expect(newState.finishReason).toBeNull(); }); test("should reset all state with RESET action", () => { const stateWithData: ChatState = { - tiptapDoc: { - type: "doc", - content: [ - { - type: "paragraph", - content: [{ type: "text", text: "Test input" }], - }, - ], - }, parentThread: { thread_id: "thread-123", messageId: "msg-456", @@ -161,9 +108,8 @@ describe("ChatState Reducer Logic", () => { const action: ChatStateAction = { type: "RESET" }; - const newState = applyAction(stateWithData, action); + const newState = chatStateReducer(stateWithData, action); - expect(newState.tiptapDoc).toBeUndefined(); expect(newState.parentThread).toBeNull(); expect(newState.finishReason).toBeNull(); }); @@ -171,107 +117,46 @@ describe("ChatState Reducer Logic", () => { test("should handle multiple sequential actions", () => { let state = initialState; - // Set tiptap doc - const doc1 = { - type: "doc" as const, - content: [ - { - type: "paragraph", - content: [{ type: "text", text: "First message" }], - }, - ], - }; - state = applyAction(state, { type: "SET_TIPTAP_DOC", payload: doc1 }); - expect(state.tiptapDoc).toEqual(doc1); - - // Start branch const parentThread: ParentThread = { thread_id: "thread-1", messageId: "msg-1", }; - state = applyAction(state, { + state = chatStateReducer(state, { type: "START_BRANCH", payload: parentThread, }); expect(state.parentThread).toEqual(parentThread); - expect(state.tiptapDoc).toEqual(doc1); // Doc persists - - // Update tiptap doc again - const doc2 = { - type: "doc" as const, - content: [ - { - type: "paragraph", - content: [{ type: "text", text: "Updated message" }], - }, - ], - }; - state = applyAction(state, { - type: "SET_TIPTAP_DOC", - payload: doc2, - }); - expect(state.tiptapDoc).toEqual(doc2); - expect(state.parentThread).toEqual(parentThread); // Branch persists - // Clear branch - state = applyAction(state, { type: "CLEAR_BRANCH" }); + state = chatStateReducer(state, { type: "CLEAR_BRANCH" }); expect(state.parentThread).toBeNull(); - expect(state.tiptapDoc).toEqual(doc2); // Doc still there - // Reset all - state = applyAction(state, { type: "RESET" }); - expect(state.tiptapDoc).toBeUndefined(); + state = chatStateReducer(state, { + type: "SET_FINISH_REASON", + payload: "stop", + }); + expect(state.finishReason).toBe("stop"); + + state = chatStateReducer(state, { type: "RESET" }); expect(state.parentThread).toBeNull(); + expect(state.finishReason).toBeNull(); }); test("should preserve state immutability", () => { - const originalDoc = { - type: "doc" as const, - content: [ - { type: "paragraph", content: [{ type: "text", text: "Original" }] }, - ], - }; - const originalState: ChatState = { - tiptapDoc: originalDoc, - parentThread: null, - finishReason: null, - }; - - const modifiedDoc = { - type: "doc" as const, - content: [ - { type: "paragraph", content: [{ type: "text", text: "Modified" }] }, - ], - }; - const action: ChatStateAction = { - type: "SET_TIPTAP_DOC", - payload: modifiedDoc, - }; - - const newState = applyAction(originalState, action); - - // Original state should not be modified - expect(originalState.tiptapDoc).toEqual(originalDoc); - expect(newState.tiptapDoc).toEqual(modifiedDoc); - expect(newState).not.toBe(originalState); - }); - - test("should handle branch context immutability", () => { const originalParentThread: ParentThread = { thread_id: "thread-1", messageId: "msg-1", }; - const stateWithBranch: ChatState = { - tiptapDoc: undefined, + const originalState: ChatState = { parentThread: originalParentThread, finishReason: null, }; - const newState = applyAction(stateWithBranch, { type: "CLEAR_BRANCH" }); + const newState = chatStateReducer(originalState, { type: "CLEAR_BRANCH" }); - // Original branch object should not be modified expect(originalParentThread.thread_id).toBe("thread-1"); + expect(originalState.parentThread).toEqual(originalParentThread); expect(newState.parentThread).toBeNull(); + expect(newState).not.toBe(originalState); }); }); diff --git a/apps/mesh/src/web/components/chat/context.tsx b/apps/mesh/src/web/components/chat/context.tsx index d8fb385ce0..4d2c14b9b8 100644 --- a/apps/mesh/src/web/components/chat/context.tsx +++ b/apps/mesh/src/web/components/chat/context.tsx @@ -29,9 +29,13 @@ import { type PropsWithChildren, Suspense, useContext, + useDeferredValue, useEffect, useReducer, + useRef, } from "react"; +import { useDecopilotEvents } from "../../hooks/use-decopilot-events"; +import { useQueryClient } from "@tanstack/react-query"; import { toast } from "sonner"; import { useModelConnections } from "../../hooks/collections/use-llm"; import { useAllowedModels } from "../../hooks/use-allowed-models"; @@ -42,6 +46,7 @@ import { ErrorBoundary } from "../error-boundary"; import { useNotification } from "../../hooks/use-notification"; import { usePreferences } from "../../hooks/use-preferences"; import { authClient } from "../../lib/auth-client"; +import { KEYS } from "../../lib/query-keys"; import { LOCALSTORAGE_KEYS } from "../../lib/localstorage-keys"; import { type ModelChangePayload, useModels } from "./select-model"; import type { VirtualMCPInfo } from "./select-virtual-mcp"; @@ -51,37 +56,19 @@ import type { ChatMessage, ChatModelsConfig, Metadata, - ParentThread, Thread, } from "./types.ts"; - +import { + chatStateReducer, + initialChatState, + type ChatState, + type ChatStateAction, +} from "./chat-state"; // ============================================================================ // Type Definitions // ============================================================================ -/** - * State shape for chat state (reducer-managed) - */ -export interface ChatState { - /** Tiptap document representing the current input (source of truth) */ - tiptapDoc: Metadata["tiptapDoc"]; - /** Active parent thread if branching is in progress */ - parentThread: ParentThread | null; - /** Finish reason from the last chat completion */ - finishReason: string | null; -} - -/** - * Actions for the chat state reducer - */ -export type ChatStateAction = - | { type: "SET_TIPTAP_DOC"; payload: Metadata["tiptapDoc"] } - | { type: "CLEAR_TIPTAP_DOC" } - | { type: "START_BRANCH"; payload: ParentThread } - | { type: "CLEAR_BRANCH" } - | { type: "SET_FINISH_REASON"; payload: string | null } - | { type: "CLEAR_FINISH_REASON" } - | { type: "RESET" }; +export type { ChatState, ChatStateAction }; /** * Shape persisted in localStorage for the selected model. @@ -111,49 +98,55 @@ type ChatFromUseChat = Pick< >; /** - * Combined context value including interaction state, thread management, and session state + * Stable context — values that change infrequently (model/thread/mode selection, actions). + * Consumers reading only stable fields skip re-renders during streaming. */ -interface ChatContextValue extends ChatFromUseChat { - // Interaction state - tiptapDoc: Metadata["tiptapDoc"]; - setTiptapDoc: (doc: Metadata["tiptapDoc"]) => void; - clearTiptapDoc: () => void; +interface ChatStableValue { + tiptapDocRef: React.RefObject; resetInteraction: () => void; - // Thread management activeThreadId: string; - createThread: () => void; // For creating new threads (with prefetch) - switchToThread: (threadId: string) => Promise; // For switching with cache prefilling + createThread: () => void; + switchToThread: (threadId: string) => Promise; threads: Thread[]; hideThread: (threadId: string) => void; - // Thread pagination (for infinite scroll) hasNextPage?: boolean; isFetchingNextPage?: boolean; fetchNextPage?: () => void; - // Virtual MCP state virtualMcps: VirtualMCPInfo[]; selectedVirtualMcp: VirtualMCPInfo | null; setVirtualMcpId: (virtualMcpId: string | null) => void; - // Model state modelsConnections: ReturnType; selectedModel: ChatModelsConfig | null; setSelectedModel: (model: ModelChangePayload) => void; - // Mode state selectedMode: ToolSelectionStrategy; setSelectedMode: (mode: ToolSelectionStrategy) => void; - // Chat state (extends useChat; sendMessage overridden, isStreaming/isChatEmpty derived) sendMessage: (tiptapDoc: Metadata["tiptapDoc"]) => Promise; + cancelRun: () => Promise; +} + +/** + * Stream context — values that change per chunk or stream lifecycle event. + * Messages are deferred via useDeferredValue so React skips intermediate renders. + */ +interface ChatStreamValue extends ChatFromUseChat { isStreaming: boolean; isChatEmpty: boolean; finishReason: string | null; clearFinishReason: () => void; + /** Derived from chat.messages (AI SDK state) to avoid stale reads during message source switches */ + isWaitingForApprovals: boolean; + /** True when thread is in_progress but we have no active local stream */ + isRunInProgress: boolean; } +type ChatContextValue = ChatStableValue & ChatStreamValue; + // ============================================================================ // Implementation // ============================================================================ @@ -164,6 +157,9 @@ const createModelsTransport = ( new DefaultChatTransport>({ api: `/api/${org}/decopilot/stream`, credentials: "include", + prepareReconnectToStreamRequest: ({ id }) => ({ + api: `/api/${org}/decopilot/attach/${id}`, + }), prepareSendMessagesRequest: ({ messages, requestMetadata = {} }) => { const { system, @@ -262,42 +258,6 @@ const useModelState = ( return [selectedModelsConfig, setModelState] as const; }; -/** - * Initial chat state - */ -const initialChatState: ChatState = { - tiptapDoc: undefined, - parentThread: null, - finishReason: null, -}; - -/** - * Reducer for chat state - */ -function chatStateReducer( - state: ChatState, - action: ChatStateAction, -): ChatState { - switch (action.type) { - case "SET_TIPTAP_DOC": - return { ...state, tiptapDoc: action.payload }; - case "CLEAR_TIPTAP_DOC": - return { ...state, tiptapDoc: undefined }; - case "START_BRANCH": - return { ...state, parentThread: action.payload }; - case "CLEAR_BRANCH": - return { ...state, parentThread: null }; - case "SET_FINISH_REASON": - return { ...state, finishReason: action.payload }; - case "CLEAR_FINISH_REASON": - return { ...state, finishReason: null }; - case "RESET": - return initialChatState; - default: - return state; - } -} - /** * Converts resource contents to UI message parts */ @@ -523,7 +483,8 @@ function derivePartsFromTiptapDoc( return parts; } -const ChatContext = createContext(null); +const ChatStableContext = createContext(null); +const ChatStreamContext = createContext(null); /** * Silent child component that auto-selects the first available model when @@ -596,6 +557,7 @@ export function ChatProvider({ children }: PropsWithChildren) { // =========================================================================== const { locator, org } = useProjectContext(); + const queryClient = useQueryClient(); // Unified thread manager hook handles all thread state and operations const threadManager = useThreadManager(); @@ -612,6 +574,9 @@ export function ChatProvider({ children }: PropsWithChildren) { initialChatState, ); + // Shared ref for tiptapDoc — ChatInput owns the state, others read the ref. + const tiptapDocRef = useRef(undefined); + // Virtual MCP state const virtualMcps = useVirtualMCPs(); const [storedSelectedVirtualMcpId, setSelectedVirtualMcpId] = useLocalStorage< @@ -679,33 +644,41 @@ export function ChatProvider({ children }: PropsWithChildren) { }) => { chatDispatch({ type: "SET_FINISH_REASON", payload: finishReason ?? null }); + const threadId = + (message.metadata as Metadata | undefined)?.thread_id ?? + threadManager.activeThreadId; + if (isAbort || isDisconnect || isError) { + // Persist partial messages so the UI doesn't flash back to stale + // server data when the message source switches from chat.messages + // to threadManager.messages (isStreaming -> false). + if (threadId && messages.length > 0) { + threadManager.updateMessagesCache(threadId, messages); + } return; } - const { thread_id } = message.metadata ?? {}; - - if (!thread_id) { + if (!threadId) { return; } + // Always persist streamed messages into the thread cache so the UI + // doesn't flash stale data when the message source switches from + // chat.messages (streaming) to threadManager.messages (server). + if (messages.length > 0) { + threadManager.updateMessagesCache(threadId, messages); + } + // Show notification (sound + browser popup) if enabled if (preferences.enableNotifications) { showNotification({ - tag: `chat-${thread_id}`, + tag: `chat-${threadId}`, title: "Decopilot is waiting for your input at", body: - threadManager.threads.find((t) => t.id === thread_id)?.title ?? + threadManager.threads.find((t) => t.id === threadId)?.title ?? "New chat", }); } - - if (finishReason !== "stop") { - return; - } - - // Update messages cache with the latest messages from the stream - threadManager.updateMessagesCache(thread_id, messages); }; const onError = (error: Error) => { @@ -749,7 +722,131 @@ export function ChatProvider({ children }: PropsWithChildren) { const isStreaming = chat.status === "submitted" || chat.status === "streaming"; - const isChatEmpty = chat.messages.length === 0; + // Computed from chat.messages (AI SDK's stable internal state) rather than + // the source-switched `messages` which briefly becomes stale between + // auto-send cycles, causing the warning banner to flicker. + const isWaitingForApprovals = (() => { + const last = chat.messages.at(-1); + if (!last || last.role !== "assistant") return false; + return last.parts.some( + (part) => "state" in part && part.state === "approval-requested", + ); + })(); + + const isChatEmpty = + chat.messages.length === 0 && threadManager.messages.length === 0; + + const activeThread = threadManager.threads.find( + (t) => t.id === threadManager.activeThreadId, + ); + const isRunInProgress = + (activeThread?.status === "in_progress" || + activeThread?.status === "expired") && + !isStreaming; + + // Ref so the SSE subscription handler can call resumeStream without + // being re-created when `chat` changes (avoids unstable closure deps). + const chatRef = useRef(chat); + chatRef.current = chat; + const hasResumedRef = useRef(null); + const resumeFailCountRef = useRef(0); + const MAX_RESUME_RETRIES = 3; + + const invalidateThreadData = () => { + queryClient.invalidateQueries({ queryKey: KEYS.threads(locator) }); + const tid = threadManager.activeThreadId; + if (tid) { + queryClient.invalidateQueries({ + predicate: (query) => { + const key = query.queryKey; + if (key[3] !== "collection" || key[4] !== "THREAD_MESSAGES") { + return false; + } + const serialized = typeof key[6] === "string" ? key[6] : ""; + return serialized.includes(tid); + }, + }); + } + }; + + // Resume an in-progress stream via the AI SDK's transport.reconnectToStream + // (GET /attach/:threadId → JetStream replay). The SDK handles all internal + // message state: status flips to "streaming", chat.messages updates live. + const tryResumeStream = (reason: string) => { + const tid = threadManager.activeThreadId; + if (!tid || hasResumedRef.current === tid) return; + if (resumeFailCountRef.current >= MAX_RESUME_RETRIES) return; + hasResumedRef.current = tid; + + console.log(`[chat] resumeStream (${reason})`, tid); + chatRef.current.resumeStream().catch((err: unknown) => { + console.error("[chat] resumeStream error", err); + resumeFailCountRef.current++; + hasResumedRef.current = null; + invalidateThreadData(); + }); + }; + + const invalidateThreadDataRef = useRef(invalidateThreadData); + invalidateThreadDataRef.current = invalidateThreadData; + + const tryResumeStreamRef = useRef(tryResumeStream); + tryResumeStreamRef.current = tryResumeStream; + + useDecopilotEvents({ + orgId: org.id, + threadId: threadManager.activeThreadId, + onStep: () => tryResumeStream("sse-step"), + onFinish: () => { + hasResumedRef.current = null; + resumeFailCountRef.current = 0; + if (!isStreaming) { + invalidateThreadData(); + } + }, + onThreadStatus: () => { + if (!isStreaming) { + invalidateThreadData(); + } + }, + }); + + // Reset resume state when switching threads so failures from one thread + // don't block resume attempts on a different thread. + // Done during render (not in useEffect) to avoid React strict-mode + // double-mount resetting the guard and firing duplicate attach requests. + const prevActiveThreadIdRef = useRef(threadManager.activeThreadId); + if (prevActiveThreadIdRef.current !== threadManager.activeThreadId) { + prevActiveThreadIdRef.current = threadManager.activeThreadId; + hasResumedRef.current = null; + resumeFailCountRef.current = 0; + } + + // Trigger resume on page load / thread switch when a background run is active. + // Also safety-net poll in case SSE events are missed (NATS at-most-once). + const SAFETY_NET_POLL_MS = 30_000; + // oxlint-disable-next-line ban-use-effect/ban-use-effect + useEffect(() => { + if (!isRunInProgress) return; + + tryResumeStreamRef.current("page-load"); + + invalidateThreadDataRef.current(); + const safetyId = setInterval( + () => invalidateThreadDataRef.current(), + SAFETY_NET_POLL_MS, + ); + + return () => { + clearInterval(safetyId); + }; + }, [isRunInProgress]); + + // Show real-time chat.messages during active streaming (local or resumed); + // otherwise use server-sourced threadManager.messages. + const messages = isStreaming + ? chat.messages + : (threadManager.messages as ChatMessage[]); // =========================================================================== // 6. RETURNED FUNCTIONS - Functions exposed via context @@ -764,11 +861,6 @@ export function ChatProvider({ children }: PropsWithChildren) { const hideThread = threadManager.hideThread; // Chat state functions - const setTiptapDoc = (doc: Metadata["tiptapDoc"]) => - chatDispatch({ type: "SET_TIPTAP_DOC", payload: doc }); - - const clearTiptapDoc = () => chatDispatch({ type: "CLEAR_TIPTAP_DOC" }); - const resetInteraction = () => chatDispatch({ type: "RESET" }); // Virtual MCP functions @@ -800,6 +892,12 @@ export function ChatProvider({ children }: PropsWithChildren) { return; } + // Sync server-sourced messages into useAIChat before sending so its + // internal state is current (needed for onFinish cache write-back and + // sendAutomaticallyWhen checks on the response). + if (threadManager.messages.length > 0) { + chatRef.current.setMessages(threadManager.messages); + } resetInteraction(); const messageMetadata: Metadata = { @@ -830,10 +928,55 @@ export function ChatProvider({ children }: PropsWithChildren) { metadata: messageMetadata, }; - await chat.sendMessage(userMessage, { metadata }); + await chatRef.current.sendMessage(userMessage, { metadata }); }; - const stop = () => chat.stop(); + const cancelRun = async () => { + const threadId = threadManager.activeThreadId; + if (!threadId) return; + hasResumedRef.current = null; + resumeFailCountRef.current = 0; + + // Snapshot streaming messages into the thread cache BEFORE stopping. + // When chat.stop() fires, isStreaming flips to false and the UI switches + // from chat.messages to threadManager.messages — this preserves the + // partial content generated up to the abort point. + if (chatRef.current.messages.length > 0) { + threadManager.updateMessagesCache(threadId, chatRef.current.messages); + } + + chatRef.current.stop(); + try { + const res = await fetch(`/api/${org.slug}/decopilot/cancel/${threadId}`, { + method: "POST", + credentials: "include", + }); + if (!res.ok) { + const data = (await res.json().catch(() => ({}))) as { + message?: string; + }; + throw new Error(data.message ?? `Cancel failed: ${res.status}`); + } + await queryClient.invalidateQueries({ queryKey: KEYS.threads(locator) }); + } catch (err) { + const msg = err instanceof Error ? err.message : "Failed to cancel"; + toast.error(msg); + console.error("[chat] cancelRun", err); + } + }; + + const stop = (): void => { + if (isStreaming) { + void cancelRun(); + } + chat.stop(); + }; + + // Wrap for context: UseChatHelpers may expect () => Promise + const stopForContext = (): Promise => { + stop(); + return Promise.resolve(); + }; const clearFinishReason = () => chatDispatch({ type: "CLEAR_FINISH_REASON" }); @@ -841,83 +984,96 @@ export function ChatProvider({ children }: PropsWithChildren) { // 7. CONTEXT VALUE & RETURN // =========================================================================== - const value: ChatContextValue = { - // Chat state - tiptapDoc: chatState.tiptapDoc, - setTiptapDoc, - clearTiptapDoc, - resetInteraction, + const deferredMessages = useDeferredValue(messages); - // Thread management (using threadManager) + const stableValue: ChatStableValue = { + tiptapDocRef, + resetInteraction, activeThreadId: threadManager.activeThreadId, threads: threadManager.threads, createThread, switchToThread, hideThread, - - // Thread pagination hasNextPage: threadManager.hasNextPage, isFetchingNextPage: threadManager.isFetchingNextPage, fetchNextPage: threadManager.fetchNextPage, - - // Virtual MCP state virtualMcps, selectedVirtualMcp, setVirtualMcpId, - - // Model state modelsConnections, selectedModel, setSelectedModel, - - // Mode state selectedMode, setSelectedMode, + sendMessage, + cancelRun, + }; - // Chat session state (from useChat) - messages: chat.messages, + const streamValue: ChatStreamValue = { + messages: deferredMessages, status: chat.status, setMessages: chat.setMessages, error: chat.error, clearError: chat.clearError, - stop, + stop: stopForContext, addToolOutput: chat.addToolOutput, addToolApprovalResponse: chat.addToolApprovalResponse, - sendMessage, isStreaming, isChatEmpty, finishReason: chatState.finishReason, clearFinishReason, + isWaitingForApprovals, + isRunInProgress, }; return ( - - {/* Auto-selects first model when none is stored. - ErrorBoundary ensures MCP errors (e.g. auth failures) never crash the provider. */} - - - - - - {children} - + + + + + + + + {children} + + ); } /** - * Hook to access the full chat context - * Returns interaction state, thread management, virtual MCP, model, and chat session state + * Stable chat values (model, mode, threads, virtual MCP, actions). + * Does NOT re-render during streaming. + */ +export function useChatStable() { + const context = useContext(ChatStableContext); + if (!context) { + throw new Error("useChatStable must be used within a ChatProvider"); + } + return context; +} + +/** + * Streaming chat values (messages, status, error, derived booleans). + * Re-renders during streaming with deferred batching. */ -export function useChat() { - const context = useContext(ChatContext); +function useChatStream() { + const context = useContext(ChatStreamContext); if (!context) { - throw new Error("useChat must be used within a ChatProvider"); + throw new Error("useChatStream must be used within a ChatProvider"); } return context; } + +/** + * Full chat context (stable + stream merged). + * Prefer useChatStable() or useChatStream() to reduce re-renders during streaming. + */ +export function useChat(): ChatContextValue { + return { ...useChatStable(), ...useChatStream() }; +} diff --git a/apps/mesh/src/web/components/chat/highlight/index.tsx b/apps/mesh/src/web/components/chat/highlight/index.tsx index cf31bc074c..166001f7cd 100644 --- a/apps/mesh/src/web/components/chat/highlight/index.tsx +++ b/apps/mesh/src/web/components/chat/highlight/index.tsx @@ -124,6 +124,7 @@ export function ChatHighlight() { clearFinishReason, messages, isStreaming, + isWaitingForApprovals, addToolOutput, sendMessage, } = useChat(); @@ -139,14 +140,6 @@ export function ChatHighlight() { (p) => p.state !== "output-available", )?.length; - // Check if any tools are awaiting approval - const isWaitingForApprovals = - lastMessage?.role === "assistant" - ? lastMessage.parts.some( - (part) => "state" in part && part.state === "approval-requested", - ) - : false; - const handleFixInChat = () => { if (error) { const text = `I encountered this error: ${error.message}. Can you help me fix it?`; diff --git a/apps/mesh/src/web/components/chat/ice-breakers.tsx b/apps/mesh/src/web/components/chat/ice-breakers.tsx index 257d4f3545..bde1ac7e89 100644 --- a/apps/mesh/src/web/components/chat/ice-breakers.tsx +++ b/apps/mesh/src/web/components/chat/ice-breakers.tsx @@ -23,7 +23,7 @@ import type { Prompt } from "@modelcontextprotocol/sdk/types.js"; import { Suspense, useReducer, useState } from "react"; import { toast } from "sonner"; import { ErrorBoundary } from "../error-boundary"; -import { useChat } from "./context"; +import { useChatStable } from "./context"; import { PromptArgsDialog, type PromptArgumentValues, @@ -238,7 +238,7 @@ function iceBreakerReducer( * @param connectionId - The connection ID, or null for the management MCP */ function IceBreakersContent({ connectionId }: { connectionId: string | null }) { - const { tiptapDoc, sendMessage } = useChat(); + const { tiptapDocRef, sendMessage } = useChatStable(); const { org } = useProjectContext(); const client = useMCPClient({ connectionId, @@ -263,7 +263,7 @@ function IceBreakersContent({ connectionId }: { connectionId: string | null }) { // Append prompt to current tiptapDoc and send // Wrap mention in a paragraph since it's an inline node - const newTiptapDoc = appendToTiptapDoc(tiptapDoc, { + const newTiptapDoc = appendToTiptapDoc(tiptapDocRef.current, { type: "paragraph", content: [ createMentionDoc({ @@ -337,7 +337,7 @@ function IceBreakersContent({ connectionId }: { connectionId: string | null }) { * Includes ErrorBoundary, Suspense, and container internally. */ export function IceBreakers({ className }: IceBreakersProps) { - const { selectedVirtualMcp } = useChat(); + const { selectedVirtualMcp } = useChatStable(); // When selectedVirtualMcp is null, use decopilot ID (default agent) const { org } = useProjectContext(); const decopilotId = getWellKnownDecopilotVirtualMCP(org.id).id; diff --git a/apps/mesh/src/web/components/chat/input.tsx b/apps/mesh/src/web/components/chat/input.tsx index e59d1e86a1..fb371a0b8b 100644 --- a/apps/mesh/src/web/components/chat/input.tsx +++ b/apps/mesh/src/web/components/chat/input.tsx @@ -30,6 +30,7 @@ import { } from "@untitledui/icons"; import type { FormEvent } from "react"; import { useEffect, useRef, useState, type MouseEvent } from "react"; +import type { Metadata } from "./types.ts"; import { useChat } from "./context"; import { ChatHighlight } from "./highlight"; import { ModeSelector } from "./select-mode"; @@ -271,8 +272,7 @@ function VirtualMCPBadge({ export function ChatInput() { const { activeThreadId, - tiptapDoc, - setTiptapDoc, + tiptapDocRef, virtualMcps, selectedVirtualMcp, setVirtualMcpId, @@ -283,10 +283,30 @@ export function ChatInput() { setSelectedMode, messages, isStreaming, + isRunInProgress, sendMessage, stop, + cancelRun, } = useChat(); + // tiptapDoc lives here (not in context) so keystrokes don't re-render + // the entire context tree. The ref on context lets IceBreakers read it. + const [tiptapDoc, setTiptapDocLocal] = + useState(undefined); + + const setTiptapDoc = (doc: Metadata["tiptapDoc"]) => { + setTiptapDocLocal(doc); + tiptapDocRef.current = doc; + }; + + // Reset input when switching threads (TiptapProvider also remounts via key) + const prevThreadRef = useRef(activeThreadId); + if (prevThreadRef.current !== activeThreadId) { + prevThreadRef.current = activeThreadId; + setTiptapDocLocal(undefined); + tiptapDocRef.current = undefined; + } + const contextWindow = selectedModel?.thinking.limits?.contextWindow; const tiptapRef = useRef(null); @@ -301,12 +321,17 @@ export function ChatInput() { const canSubmit = !isStreaming && !!selectedModel && !isTiptapDocEmpty(tiptapDoc); + const showStopOrCancel = isStreaming || isRunInProgress; + const handleSubmit = (e?: FormEvent) => { e?.preventDefault(); if (isStreaming) { stop(); + } else if (isRunInProgress) { + void cancelRun(); } else if (canSubmit && tiptapDoc) { void sendMessage(tiptapDoc); + setTiptapDoc(undefined); } }; @@ -368,6 +393,11 @@ export function ChatInput() {
{/* Left Actions (agent selector and usage stats) */}
+ {isRunInProgress && ( + + Run in progress + + )} {/* Always show selector button - DecopilotIconButton for Decopilot, VirtualMCPSelector for others */} {selectedVirtualMcp && isDecopilot(selectedVirtualMcp.id) ? (
diff --git a/apps/mesh/src/web/components/chat/message/parts/tool-call-part/subtask.tsx b/apps/mesh/src/web/components/chat/message/parts/tool-call-part/subtask.tsx index 96d972aadd..c8b2834df0 100644 --- a/apps/mesh/src/web/components/chat/message/parts/tool-call-part/subtask.tsx +++ b/apps/mesh/src/web/components/chat/message/parts/tool-call-part/subtask.tsx @@ -4,7 +4,7 @@ import type { ToolSubtaskMetadata } from "../../use-filter-parts.ts"; import { IntegrationIcon } from "@/web/components/integration-icon"; import type { ToolDefinition } from "@decocms/mesh-sdk"; import { Users03 } from "@untitledui/icons"; -import { useChat } from "../../../context.tsx"; +import { useChatStable } from "../../../context.tsx"; import type { SubtaskToolPart } from "../../../types.ts"; import { extractTextFromOutput, getToolPartErrorText } from "../utils.ts"; import { ToolCallShell } from "./common.tsx"; @@ -27,7 +27,7 @@ export function SubtaskPart({ annotations, latency, }: SubtaskPartProps) { - const { virtualMcps } = useChat(); + const { virtualMcps } = useChatStable(); // State computation const isInputStreaming = diff --git a/apps/mesh/src/web/components/chat/popover-threads.tsx b/apps/mesh/src/web/components/chat/popover-threads.tsx index 5a1dc1951f..d9c7fcd43d 100644 --- a/apps/mesh/src/web/components/chat/popover-threads.tsx +++ b/apps/mesh/src/web/components/chat/popover-threads.tsx @@ -13,7 +13,7 @@ import { import { cn } from "@deco/ui/lib/utils.ts"; import { Clock, SearchMd, Trash01 } from "@untitledui/icons"; import { useRef, useState } from "react"; -import { useChat } from "./context"; +import { useChatStable } from "./context"; import type { Thread } from "./types.ts"; type ThreadSection = { @@ -116,7 +116,7 @@ export function ThreadHistoryPopover({ isFetchingNextPage, fetchNextPage, hideThread, - } = useChat(); + } = useChatStable(); const sentinelRef = useRef(null); // Set up intersection observer for infinite scroll diff --git a/apps/mesh/src/web/components/chat/thread/cache-operations.ts b/apps/mesh/src/web/components/chat/thread/cache-operations.ts index cce02cd822..d6af8cca8e 100644 --- a/apps/mesh/src/web/components/chat/thread/cache-operations.ts +++ b/apps/mesh/src/web/components/chat/thread/cache-operations.ts @@ -54,6 +54,7 @@ export function updateThreadInCache( created_at: currentThread.created_at, updated_at: updates.updated_at ?? currentThread.updated_at, hidden: updates.hidden ?? currentThread.hidden, + status: updates.status ?? currentThread.status, }; updatedItems[threadIndex] = updatedThread; diff --git a/apps/mesh/src/web/components/chat/thread/types.ts b/apps/mesh/src/web/components/chat/thread/types.ts index c3d6b5c684..463606404b 100644 --- a/apps/mesh/src/web/components/chat/thread/types.ts +++ b/apps/mesh/src/web/components/chat/thread/types.ts @@ -1,3 +1,5 @@ +import type { ThreadDisplayStatus } from "@decocms/mesh-sdk"; + // Constants export const THREAD_CONSTANTS = { /** Page size for thread messages queries */ @@ -15,6 +17,8 @@ export interface Thread { created_at: string; // ISO string updated_at: string; // ISO string hidden?: boolean; + /** Execution status from server — includes virtual "expired" for stale in_progress threads */ + status?: ThreadDisplayStatus; } export type { ChatMessage } from "../types.ts"; diff --git a/apps/mesh/src/web/components/details/virtual-mcp/index.tsx b/apps/mesh/src/web/components/details/virtual-mcp/index.tsx index 4cc4da7f40..d133014945 100644 --- a/apps/mesh/src/web/components/details/virtual-mcp/index.tsx +++ b/apps/mesh/src/web/components/details/virtual-mcp/index.tsx @@ -1,5 +1,5 @@ import type { VirtualMCPEntity } from "@/tools/virtual/schema"; -import { useChat } from "@/web/components/chat/context"; +import { useChatStable } from "@/web/components/chat/context"; import { EmptyState } from "@/web/components/empty-state.tsx"; import { ErrorBoundary } from "@/web/components/error-boundary"; import { IntegrationIcon } from "@/web/components/integration-icon.tsx"; @@ -216,7 +216,7 @@ function VirtualMcpDetailViewWithData({ // Auto-open chat with this agent selected const [, setChatOpen] = useDecoChatOpen(); - const { setVirtualMcpId } = useChat(); + const { setVirtualMcpId } = useChatStable(); // Open chat on mount (without selecting the agent) // oxlint-disable-next-line ban-use-effect/ban-use-effect diff --git a/apps/mesh/src/web/components/details/workflow/hooks/use-workflow-sse.ts b/apps/mesh/src/web/components/details/workflow/hooks/use-workflow-sse.ts index 0c67b51bc7..57bae8cf01 100644 --- a/apps/mesh/src/web/components/details/workflow/hooks/use-workflow-sse.ts +++ b/apps/mesh/src/web/components/details/workflow/hooks/use-workflow-sse.ts @@ -15,20 +15,23 @@ import { useSyncExternalStore } from "react"; import { useQueryClient, type QueryClient } from "@tanstack/react-query"; import { useProjectContext } from "@decocms/mesh-sdk"; +import { createSSESubscription } from "../../../../hooks/create-sse-subscription"; // ============================================================================ -// Shared EventSource per org (ref-counted) +// Shared connection pool // ============================================================================ -interface SharedConnection { - es: EventSource; - refCount: number; - queryClients: Set; - /** Pending debounce timer for coalescing invalidations */ - debounceTimer: ReturnType | null; -} +const WORKFLOW_EVENT_TYPES = [ + "workflow.execution.created", + "workflow.execution.resumed", + "workflow.step.execute", + "workflow.step.completed", +]; -const connections = new Map(); +const workflowSSE = createSSESubscription({ + buildUrl: (orgId) => `/org/${orgId}/watch?types=workflow.*`, + eventTypes: WORKFLOW_EVENT_TYPES, +}); /** Tool names whose query caches should be invalidated on workflow events */ const INVALIDATION_TARGETS = [ @@ -37,18 +40,16 @@ const INVALIDATION_TARGETS = [ "COLLECTION_WORKFLOW_EXECUTION_GET_STEP_RESULT", ]; -const WORKFLOW_EVENT_TYPES = [ - "workflow.execution.created", - "workflow.execution.resumed", - "workflow.step.execute", - "workflow.step.completed", -]; - /** Debounce window — coalesce rapid SSE events into one invalidation */ const DEBOUNCE_MS = 500; -function invalidateAllClients(conn: SharedConnection): void { - for (const client of conn.queryClients) { +const debounceTimers = new Map>(); +const queryClients = new Map>(); + +function invalidateAllClients(orgId: string): void { + const clients = queryClients.get(orgId); + if (!clients) return; + for (const client of clients) { client.invalidateQueries({ predicate: (query) => query.queryKey.some( @@ -58,47 +59,18 @@ function invalidateAllClients(conn: SharedConnection): void { } } -function scheduleInvalidation(conn: SharedConnection): void { - // If a timer is already pending, the upcoming flush will cover this event too - if (conn.debounceTimer !== null) return; +function scheduleInvalidation(orgId: string): void { + if (debounceTimers.has(orgId)) return; - conn.debounceTimer = setTimeout(() => { - conn.debounceTimer = null; - invalidateAllClients(conn); - }, DEBOUNCE_MS); + debounceTimers.set( + orgId, + setTimeout(() => { + debounceTimers.delete(orgId); + invalidateAllClients(orgId); + }, DEBOUNCE_MS), + ); } -function getOrCreateConnection(orgId: string): SharedConnection { - let conn = connections.get(orgId); - - if (!conn) { - const url = `/org/${orgId}/watch?types=workflow.*`; - const es = new EventSource(url); - - conn = { es, refCount: 0, queryClients: new Set(), debounceTimer: null }; - connections.set(orgId, conn); - - const onEvent = () => scheduleInvalidation(conn!); - - for (const eventType of WORKFLOW_EVENT_TYPES) { - es.addEventListener(eventType, onEvent); - } - - es.onerror = () => { - if (es.readyState === EventSource.CLOSED) { - if (conn!.debounceTimer !== null) { - clearTimeout(conn!.debounceTimer); - } - connections.delete(orgId); - } - }; - } - - return conn; -} - -// Snapshot is constant — we don't derive render state from SSE, -// we only use the subscription for its side-effect (query invalidation). const getSnapshot = () => 0; // ============================================================================ @@ -117,34 +89,33 @@ const getSnapshot = () => 0; export function useWorkflowSSE(): void { const { org } = useProjectContext(); const queryClient = useQueryClient(); - const orgId = org.id; const subscribe = (onStoreChange: () => void) => { - const conn = getOrCreateConnection(orgId); - conn.refCount++; - conn.queryClients.add(queryClient); - - // Attach per-subscriber handler so useSyncExternalStore can track changes - const handler = () => onStoreChange(); - for (const eventType of WORKFLOW_EVENT_TYPES) { - conn.es.addEventListener(eventType, handler); + let clients = queryClients.get(orgId); + if (!clients) { + clients = new Set(); + queryClients.set(orgId, clients); } + clients.add(queryClient); - return () => { - for (const eventType of WORKFLOW_EVENT_TYPES) { - conn.es.removeEventListener(eventType, handler); - } + const handler = () => { + scheduleInvalidation(orgId); + onStoreChange(); + }; - conn.queryClients.delete(queryClient); - conn.refCount--; + const unsubscribe = workflowSSE.subscribe(orgId, handler); - if (conn.refCount <= 0) { - if (conn.debounceTimer !== null) { - clearTimeout(conn.debounceTimer); - } - conn.es.close(); - connections.delete(orgId); + return () => { + unsubscribe(); + clients!.delete(queryClient); + if (clients!.size === 0) { + queryClients.delete(orgId); + } + const timer = debounceTimers.get(orgId); + if (timer && !queryClients.has(orgId)) { + clearTimeout(timer); + debounceTimers.delete(orgId); } }; }; diff --git a/apps/mesh/src/web/components/home/agents-list.tsx b/apps/mesh/src/web/components/home/agents-list.tsx index 2b42dc95a7..1aa4d2d069 100644 --- a/apps/mesh/src/web/components/home/agents-list.tsx +++ b/apps/mesh/src/web/components/home/agents-list.tsx @@ -5,7 +5,7 @@ * Only shows when the organization has agents. */ -import { useChat } from "@/web/components/chat/context"; +import { useChatStable } from "@/web/components/chat/context"; import { VirtualMCPPopoverContent, type VirtualMCPInfo, @@ -34,7 +34,7 @@ function AgentPreview({ icon?: string | null; }; }) { - const { setVirtualMcpId } = useChat(); + const { setVirtualMcpId } = useChatStable(); const handleClick = () => { // Select the agent in the chat context @@ -138,7 +138,7 @@ function SeeAllButton({ */ function AgentsListContent() { const virtualMcps = useVirtualMCPs(); - const { selectedVirtualMcp, setVirtualMcpId } = useChat(); + const { selectedVirtualMcp, setVirtualMcpId } = useChatStable(); // Filter out the default Decopilot agent (it's not a real agent) const agents = virtualMcps diff --git a/apps/mesh/src/web/hooks/create-sse-subscription.ts b/apps/mesh/src/web/hooks/create-sse-subscription.ts new file mode 100644 index 0000000000..9ae9d9ac47 --- /dev/null +++ b/apps/mesh/src/web/hooks/create-sse-subscription.ts @@ -0,0 +1,81 @@ +/** + * Shared SSE subscription factory + * + * Manages ref-counted EventSource connections so multiple React components + * can subscribe to the same SSE endpoint without opening duplicate connections. + * + * Each call to `createSSESubscription` creates an independent connection pool + * keyed by a caller-provided key (typically an orgId). + */ + +interface SharedConnection { + es: EventSource; + refCount: number; +} + +export interface SSESubscriptionOptions { + /** URL builder given a connection key */ + buildUrl: (key: string) => string; + /** SSE event types to listen for */ + eventTypes: string[]; +} + +export interface SSESubscription { + /** + * Subscribe to SSE events for the given key. + * Returns an unsubscribe function. + * + * Multiple subscribers share one EventSource per key; the connection + * is closed when the last subscriber unsubscribes. + */ + subscribe: (key: string, handler: (e: MessageEvent) => void) => () => void; +} + +export function createSSESubscription( + options: SSESubscriptionOptions, +): SSESubscription { + const { buildUrl, eventTypes } = options; + const connections = new Map(); + + function getOrCreate(key: string): SharedConnection { + let conn = connections.get(key); + if (!conn) { + const es = new EventSource(buildUrl(key)); + conn = { es, refCount: 0 }; + connections.set(key, conn); + + es.onerror = () => { + if (es.readyState === EventSource.CLOSED) { + connections.delete(key); + } + }; + } + return conn; + } + + return { + subscribe(key, handler) { + const conn = getOrCreate(key); + conn.refCount++; + + for (const type of eventTypes) { + conn.es.addEventListener(type, handler); + } + + let unsubscribed = false; + return () => { + if (unsubscribed) return; + unsubscribed = true; + + for (const type of eventTypes) { + conn.es.removeEventListener(type, handler); + } + conn.refCount--; + if (conn.refCount <= 0) { + conn.es.close(); + connections.delete(key); + } + }; + }, + }; +} diff --git a/apps/mesh/src/web/hooks/use-decopilot-events.ts b/apps/mesh/src/web/hooks/use-decopilot-events.ts new file mode 100644 index 0000000000..920b2b94b2 --- /dev/null +++ b/apps/mesh/src/web/hooks/use-decopilot-events.ts @@ -0,0 +1,145 @@ +/** + * useDecopilotEvents — Subscribe to typed decopilot SSE events + * + * Connects to the /org/:orgId/watch SSE endpoint, parses incoming events + * into the discriminated DecopilotSSEEvent union, filters by threadId when + * provided, and dispatches to typed handlers. + * + * Uses useSyncExternalStore for proper React 19 subscription lifecycle. + * EventSource connections are ref-counted so multiple call-sites share one + * connection per organization. + */ + +import { + DECOPILOT_EVENTS, + ALL_DECOPILOT_EVENT_TYPES, + type DecopilotSSEEvent, + type DecopilotStepEvent, + type DecopilotFinishEvent, + type DecopilotThreadStatusEvent, +} from "@decocms/mesh-sdk"; +import { useRef, useSyncExternalStore } from "react"; +import { createSSESubscription } from "./create-sse-subscription"; + +// ============================================================================ +// Shared connection pool +// ============================================================================ + +const decopilotSSE = createSSESubscription({ + buildUrl: (orgId) => { + const typesParam = ALL_DECOPILOT_EVENT_TYPES.join(","); + return `/org/${orgId}/watch?types=${typesParam}`; + }, + eventTypes: [...ALL_DECOPILOT_EVENT_TYPES], +}); + +const getSnapshot = () => 0; + +// ============================================================================ +// Hook +// ============================================================================ + +export interface UseDecopilotEventsOptions { + /** Organization ID for the SSE endpoint */ + orgId: string; + /** Only fire handlers for events matching this thread (omit for all threads) */ + threadId?: string; + /** Disable the SSE connection (default: true) */ + enabled?: boolean; + /** Called on each "decopilot.step" event (new content available) */ + onStep?: (event: DecopilotStepEvent) => void; + /** Called on each "decopilot.finish" event (stream ended) */ + onFinish?: (event: DecopilotFinishEvent) => void; + /** Called on each "decopilot.thread.status" event (thread status changed) */ + onThreadStatus?: (event: DecopilotThreadStatusEvent) => void; +} + +interface CallbacksRef { + threadId?: string; + onStep?: (event: DecopilotStepEvent) => void; + onFinish?: (event: DecopilotFinishEvent) => void; + onThreadStatus?: (event: DecopilotThreadStatusEvent) => void; +} + +/** + * Subscribe to decopilot SSE events with full type safety. + * + * The underlying EventSource is ref-counted per orgId, so multiple + * components can subscribe without opening duplicate connections. + * + * Callbacks and threadId are read from a ref so the `subscribe` function + * identity only changes when `enabled` or `orgId` change — keeping the + * EventSource connection stable across re-renders. + */ +export function useDecopilotEvents(options: UseDecopilotEventsOptions): void { + const { + orgId, + threadId, + enabled = true, + onStep, + onFinish, + onThreadStatus, + } = options; + + const callbacksRef = useRef({ + threadId, + onStep, + onFinish, + onThreadStatus, + }); + callbacksRef.current = { threadId, onStep, onFinish, onThreadStatus }; + + // `subscribe` only depends on `enabled` and `orgId` so the EventSource + // connection is not torn down when callbacks or threadId change. + const subscribeRef = useRef< + ((onStoreChange: () => void) => () => void) | null + >(null); + + const prevEnabled = useRef(enabled); + const prevOrgId = useRef(orgId); + + if ( + !subscribeRef.current || + prevEnabled.current !== enabled || + prevOrgId.current !== orgId + ) { + prevEnabled.current = enabled; + prevOrgId.current = orgId; + + subscribeRef.current = (onStoreChange: () => void) => { + if (!enabled || !orgId) { + return () => {}; + } + + const handler = (e: MessageEvent) => { + let event: DecopilotSSEEvent; + try { + event = JSON.parse(e.data) as DecopilotSSEEvent; + } catch { + return; + } + + const cb = callbacksRef.current; + if (cb.threadId && event.subject !== cb.threadId) return; + + switch (event.type) { + case DECOPILOT_EVENTS.STEP: + cb.onStep?.(event); + break; + case DECOPILOT_EVENTS.FINISH: + cb.onFinish?.(event); + break; + case DECOPILOT_EVENTS.THREAD_STATUS: + cb.onThreadStatus?.(event); + break; + } + + onStoreChange(); + }; + + return decopilotSSE.subscribe(orgId, handler); + }; + } + + useSyncExternalStore(subscribeRef.current, getSnapshot, getSnapshot); +} diff --git a/apps/mesh/src/web/routes/tasks.tsx b/apps/mesh/src/web/routes/tasks.tsx index 68d0d549ab..78fdd23d93 100644 --- a/apps/mesh/src/web/routes/tasks.tsx +++ b/apps/mesh/src/web/routes/tasks.tsx @@ -1,4 +1,3 @@ -import { useChat } from "@/web/components/chat"; import { CollectionDisplayButton } from "@/web/components/collections/collection-display-button.tsx"; import { CollectionSearch } from "@/web/components/collections/collection-search.tsx"; import { CollectionTableWrapper } from "@/web/components/collections/collection-table-wrapper.tsx"; @@ -16,6 +15,7 @@ import { SELF_MCP_ALIAS_ID, useMCPClient, useProjectContext, + type ThreadDisplayStatus, } from "@decocms/mesh-sdk"; import { Breadcrumb, @@ -33,16 +33,24 @@ import { Clock, } from "@untitledui/icons"; import { useNavigate } from "@tanstack/react-router"; -import { useSuspenseQuery } from "@tanstack/react-query"; -import { Suspense } from "react"; +import { useSuspenseQuery, useQueryClient } from "@tanstack/react-query"; +import { Suspense, useState } from "react"; +import { useDecopilotEvents } from "@/web/hooks/use-decopilot-events"; +import { useChatStable } from "../components/chat/context"; -function TaskStatusBadge({ status }: { status: string }) { +function TaskStatusBadge({ + status, + stepCount, +}: { + status: ThreadDisplayStatus; + stepCount?: number; +}) { switch (status) { case "in_progress": return ( - Running + {stepCount ? `Running · step ${stepCount}` : "Running"} ); case "completed": @@ -95,7 +103,31 @@ function TasksContent() { orgId: org.id, }); const navigate = useNavigate(); - const { switchToThread } = useChat(); + const { switchToThread } = useChatStable(); + const queryClient = useQueryClient(); + + const [stepCounts, setStepCounts] = useState>(new Map()); + + useDecopilotEvents({ + orgId: org.id, + onStep: (event) => { + setStepCounts((prev) => { + const next = new Map(prev); + next.set(event.subject, event.data.stepCount); + return next; + }); + }, + onThreadStatus: (event) => { + queryClient.invalidateQueries({ queryKey: KEYS.taskThreads(locator) }); + if (event.data.status !== "in_progress") { + setStepCounts((prev) => { + const next = new Map(prev); + next.delete(event.subject); + return next; + }); + } + }, + }); // useListState and ThreadEntity both use snake_case for audit fields const listState = useListState({ @@ -165,7 +197,10 @@ function TasksContent() { id: "status", header: "Status", render: (thread) => ( - + ), cellClassName: "w-40 shrink-0", sortable: true, diff --git a/packages/mesh-sdk/src/index.ts b/packages/mesh-sdk/src/index.ts index 366bbe74aa..23fc507a7b 100644 --- a/packages/mesh-sdk/src/index.ts +++ b/packages/mesh-sdk/src/index.ts @@ -104,6 +104,22 @@ export { type VirtualMCPCreateData, type VirtualMCPUpdateData, type VirtualMCPConnection, + // Decopilot event types + THREAD_STATUSES, + THREAD_DISPLAY_STATUSES, + DECOPILOT_EVENTS, + ALL_DECOPILOT_EVENT_TYPES, + createDecopilotStepEvent, + createDecopilotFinishEvent, + createDecopilotThreadStatusEvent, + type ThreadStatus, + type ThreadDisplayStatus, + type DecopilotEventType, + type DecopilotStepEvent, + type DecopilotFinishEvent, + type DecopilotThreadStatusEvent, + type DecopilotSSEEvent, + type DecopilotEventMap, } from "./types"; // Streamable HTTP transport diff --git a/packages/mesh-sdk/src/types/decopilot-events.ts b/packages/mesh-sdk/src/types/decopilot-events.ts new file mode 100644 index 0000000000..0b2a0bbfe4 --- /dev/null +++ b/packages/mesh-sdk/src/types/decopilot-events.ts @@ -0,0 +1,128 @@ +/** + * Decopilot SSE Event Types + * + * Canonical type definitions for thread statuses and decopilot SSE events. + * Shared between server (emitter) and client (consumer) for full type safety. + */ + +// ============================================================================ +// Thread Status +// ============================================================================ + +/** Persisted thread statuses (written to DB). */ +export const THREAD_STATUSES = [ + "in_progress", + "requires_action", + "failed", + "completed", +] as const; +export type ThreadStatus = (typeof THREAD_STATUSES)[number]; + +/** + * Display statuses include "expired" — a virtual status computed at read time + * for threads stuck in "in_progress" beyond a timeout threshold. + * Never persisted to DB, but appears in API responses and UI. + */ +export const THREAD_DISPLAY_STATUSES = [...THREAD_STATUSES, "expired"] as const; +export type ThreadDisplayStatus = (typeof THREAD_DISPLAY_STATUSES)[number]; + +// ============================================================================ +// SSE Event Type Constants +// ============================================================================ + +export const DECOPILOT_EVENTS = { + STEP: "decopilot.step", + FINISH: "decopilot.finish", + THREAD_STATUS: "decopilot.thread.status", +} as const; + +export type DecopilotEventType = + (typeof DECOPILOT_EVENTS)[keyof typeof DECOPILOT_EVENTS]; + +export const ALL_DECOPILOT_EVENT_TYPES: DecopilotEventType[] = + Object.values(DECOPILOT_EVENTS); + +// ============================================================================ +// Event Payloads (discriminated union on `type`) +// ============================================================================ + +interface BaseDecopilotEvent { + id: string; + source: "decopilot"; + /** Thread ID this event relates to */ + subject: string; + time: string; +} + +export interface DecopilotStepEvent extends BaseDecopilotEvent { + type: typeof DECOPILOT_EVENTS.STEP; + data: { stepCount: number }; +} + +export interface DecopilotFinishEvent extends BaseDecopilotEvent { + type: typeof DECOPILOT_EVENTS.FINISH; + data: { status: ThreadStatus }; +} + +export interface DecopilotThreadStatusEvent extends BaseDecopilotEvent { + type: typeof DECOPILOT_EVENTS.THREAD_STATUS; + data: { status: ThreadStatus }; +} + +export type DecopilotSSEEvent = + | DecopilotStepEvent + | DecopilotFinishEvent + | DecopilotThreadStatusEvent; + +/** Map from event type string → typed payload (useful for generic handlers) */ +export interface DecopilotEventMap { + [DECOPILOT_EVENTS.STEP]: DecopilotStepEvent; + [DECOPILOT_EVENTS.FINISH]: DecopilotFinishEvent; + [DECOPILOT_EVENTS.THREAD_STATUS]: DecopilotThreadStatusEvent; +} + +// ============================================================================ +// Server-side Factories (create typed events for SSEHub.emit) +// ============================================================================ + +export function createDecopilotStepEvent( + threadId: string, + stepCount: number, +): DecopilotStepEvent { + return { + id: crypto.randomUUID(), + type: DECOPILOT_EVENTS.STEP, + source: "decopilot", + subject: threadId, + data: { stepCount }, + time: new Date().toISOString(), + }; +} + +export function createDecopilotFinishEvent( + threadId: string, + status: ThreadStatus, +): DecopilotFinishEvent { + return { + id: crypto.randomUUID(), + type: DECOPILOT_EVENTS.FINISH, + source: "decopilot", + subject: threadId, + data: { status }, + time: new Date().toISOString(), + }; +} + +export function createDecopilotThreadStatusEvent( + threadId: string, + status: ThreadStatus, +): DecopilotThreadStatusEvent { + return { + id: crypto.randomUUID(), + type: DECOPILOT_EVENTS.THREAD_STATUS, + source: "decopilot", + subject: threadId, + data: { status }, + time: new Date().toISOString(), + }; +} diff --git a/packages/mesh-sdk/src/types/index.ts b/packages/mesh-sdk/src/types/index.ts index ab1b70003a..aa790a869a 100644 --- a/packages/mesh-sdk/src/types/index.ts +++ b/packages/mesh-sdk/src/types/index.ts @@ -24,3 +24,21 @@ export { type VirtualMCPUpdateData, type VirtualMCPConnection, } from "./virtual-mcp"; + +export { + THREAD_STATUSES, + THREAD_DISPLAY_STATUSES, + DECOPILOT_EVENTS, + ALL_DECOPILOT_EVENT_TYPES, + createDecopilotStepEvent, + createDecopilotFinishEvent, + createDecopilotThreadStatusEvent, + type ThreadStatus, + type ThreadDisplayStatus, + type DecopilotEventType, + type DecopilotStepEvent, + type DecopilotFinishEvent, + type DecopilotThreadStatusEvent, + type DecopilotSSEEvent, + type DecopilotEventMap, +} from "./decopilot-events";