diff --git a/README.md b/README.md index 967ee1c83..4d1037465 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,66 @@ Optional local services (provided in docker-compose.yml for dev): - NCPS (Nix cache proxy) on 8501 - Prometheus (9090), Grafana (3000), cAdvisor (8080) +## Authentication & OIDC setup (read before configuring environments) + +Agyn supports two authentication modes controlled by `AUTH_MODE`: + +- `single_user` (default): skips login and binds every request to the built-in `default@local` user (`00000000-0000-0000-0000-000000000001`). Use this only for air‑gapped demos—the default user owns every thread and there is no access control. +- `oidc`: enables the `/api/auth/login` → `/api/auth/oidc/callback` flow, persists users by issuer/subject, and issues signed `agyn_session` cookies per authenticated user. + +### Required environment in OIDC mode + +When `AUTH_MODE=oidc`, the server refuses to boot until the following are present: + +| Variable | Purpose | +| --- | --- | +| `AUTH_MODE=oidc` | Opt-in to federated auth. | +| `SESSION_SECRET` | 32+ character random string used to sign session cookies; must remain stable across restarts and replicas. | +| `OIDC_ISSUER_URL` | Discovery URL (e.g., `https://login.example.com/realms/agents`). | +| `OIDC_CLIENT_ID` | OAuth client identifier registered with your IdP. | +| `OIDC_CLIENT_SECRET` | Optional; supply when your IdP requires confidential clients. Leave blank only if the provider allows public clients. | +| `OIDC_REDIRECT_URI` | Must route to `https:///api/auth/oidc/callback`. This exact URI must also be registered with the IdP. | +| `OIDC_SCOPES` | Space/comma separated scopes (default `openid profile email`). | +| `OIDC_POST_LOGIN_REDIRECT` | Path relative to the UI origin to land on after login (default `/`). | + +Example `.env` excerpt for local testing: + +``` +AUTH_MODE=oidc +SESSION_SECRET=dev-0123456789abcdef0123456789abcdef +OIDC_ISSUER_URL=https://auth.local/realms/dev +OIDC_CLIENT_ID=agyn-local +OIDC_CLIENT_SECRET=local-secret +OIDC_REDIRECT_URI=http://localhost:3010/api/auth/oidc/callback +OIDC_SCOPES=openid profile email offline_access +OIDC_POST_LOGIN_REDIRECT=/threads +``` + +### Redirect + session behavior + +- The callback endpoint is always `GET /api/auth/oidc/callback`; set `OIDC_REDIRECT_URI` to this path on the API origin (`http://localhost:3010` in dev, your HTTPS hostname in prod). +- Successful callbacks create a 30-day `agyn_session` cookie (`HttpOnly`, `SameSite=Lax`, `Secure` in production). Clients must send this cookie on every request; the server verifies it using `SESSION_SECRET` and loads the user via Prisma. +- Logging out calls `POST /api/auth/logout`, deletes the server-side session row, and clears the cookie. + +### Local development tips + +1. **Same-origin (simplest):** Build and serve `platform-ui` through nginx (or run both services behind the same host/port). No extra CORS or credential settings are required. +2. **Cross-origin (Vite dev server → API):** + - Set `CORS_ORIGINS=http://localhost:5173` (or whatever hosts the UI) + - Ensure every UI fetch/axios call includes credentials, e.g. `fetch(url, { credentials: 'include' })` or `axios.create({ withCredentials: true })` + - Keep `VITE_API_BASE_URL` pointed at the API origin (e.g., `http://localhost:3010`) + - Update `OIDC_REDIRECT_URI` to the API origin even if the UI runs elsewhere; the IdP redirects into the API, which then forwards the browser to `OIDC_POST_LOGIN_REDIRECT`. +3. Restarting the server rotates the default (non-random) `SESSION_SECRET`; for cross-origin dev keep a stable secret in `.env` so cookies remain valid after reloads. + +### Troubleshooting + +- **`oidc_disabled` errors**: `AUTH_MODE` is still `single_user` or the server restarted without the OIDC env block. +- **Redirect loops or `invalid_grant`**: The IdP callback URL must exactly match `OIDC_REDIRECT_URI`, including scheme/port. Regenerate the client if needed. +- **Cookie missing in the browser**: Confirm `CORS_ORIGINS` allows the UI origin and the client sends requests with credentials. On HTTPS sites, ensure you are not hitting the API via plain HTTP because the cookie is marked `Secure` in production. +- **`Session cookie signature mismatch` warnings**: All replicas must share the same `SESSION_SECRET`; rotating it invalidates existing sessions. + +--- + ### Setup 1) Clone and install: ```bash diff --git a/packages/platform-server/.env.example b/packages/platform-server/.env.example index f0c2b3f91..278b1cdbb 100644 --- a/packages/platform-server/.env.example +++ b/packages/platform-server/.env.example @@ -17,6 +17,19 @@ LLM_PROVIDER= LITELLM_BASE_URL=http://127.0.0.1:4000 LITELLM_MASTER_KEY=sk-dev-master-1234 +# Authentication. Modes: single_user (default) or oidc +AUTH_MODE=single_user +# Must be at least 32 characters. Replace in production. +SESSION_SECRET=dev-session-secret-change-me-0123456789abcdef + +# OIDC (set AUTH_MODE=oidc to enable) +# OIDC_ISSUER_URL=https://your-idp/.well-known/openid-configuration +# OIDC_CLIENT_ID= +# OIDC_CLIENT_SECRET= +# OIDC_REDIRECT_URI=http://localhost:3010/api/auth/oidc/callback +# OIDC_SCOPES=openid profile email +# OIDC_POST_LOGIN_REDIRECT=http://localhost:4173 + # Optional: GitHub integration (App or PAT). Safe to omit for local dev. # GITHUB_APP_ID= # GITHUB_APP_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n" diff --git a/packages/platform-server/__e2e__/graph.socket.gateway.e2e.test.ts b/packages/platform-server/__e2e__/graph.socket.gateway.e2e.test.ts index ff5ea278a..c44bbc494 100644 --- a/packages/platform-server/__e2e__/graph.socket.gateway.e2e.test.ts +++ b/packages/platform-server/__e2e__/graph.socket.gateway.e2e.test.ts @@ -14,8 +14,10 @@ import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; import { LiveGraphRuntime } from '../src/graph-core/liveGraph.manager'; import { ThreadsMetricsService } from '../src/agents/threads.metrics.service'; import { PrismaService } from '../src/core/services/prisma.service'; +import { ConfigService } from '../src/core/services/config.service'; import { ContainerTerminalGateway } from '../src/infra/container/terminal.gateway'; import { TerminalSessionsService, type TerminalSessionRecord } from '../src/infra/container/terminal.sessions.service'; +import { AuthService } from '../src/auth/auth.service'; import { WorkspaceProvider, type WorkspaceKey, @@ -47,18 +49,32 @@ class ThreadsMetricsServiceStub { class PrismaServiceStub { private readonly runEvents = new Map(); + private readonly threadOwners = new Map(); setRunEvent(event: { id: string }): void { this.runEvents.set(event.id, event); } + setThreadOwner(threadId: string, ownerUserId: string): void { + this.threadOwners.set(threadId, ownerUserId); + } + clear(): void { this.runEvents.clear(); + this.threadOwners.clear(); } getClient() { return { $queryRaw: async () => [], + thread: { + findUnique: async ({ where }: { where: { id: string } }) => { + const id = where?.id; + if (!id) return null; + const ownerUserId = this.threadOwners.get(id); + return ownerUserId ? { id, ownerUserId } : null; + }, + }, runEvent: { findUnique: async ({ where }: { where: { id: string } }) => { const id = where?.id; @@ -154,6 +170,17 @@ class WorkspaceProviderStub extends WorkspaceProvider { } } +class AuthServiceStub { + async resolvePrincipalFromCookieHeader(): Promise<{ userId: string }> { + return { userId: 'test-user' }; + } +} + +class ConfigServiceStub { + corsOrigins: string[] | null = null; + isProduction = false; +} + class TerminalSessionsServiceStub { public connected = false; public closed = false; @@ -322,6 +349,8 @@ describe('Socket gateway real server handshakes', () => { { provide: WorkspaceProvider, useClass: WorkspaceProviderStub }, EventsBusService, RunEventsService, + { provide: AuthService, useClass: AuthServiceStub }, + { provide: ConfigService, useClass: ConfigServiceStub }, ], }).compile(); @@ -389,6 +418,8 @@ describe('Socket gateway real server handshakes', () => { const threadId = 'thread-123'; const runId = 'run-456'; + const ownerUserId = 'test-user'; + prismaStub.setThreadOwner(threadId, ownerUserId); const messagePromise = new Promise>((resolve, reject) => { const timer = setTimeout(() => { @@ -448,7 +479,7 @@ describe('Socket gateway real server handshakes', () => { expect(ack.rooms).toEqual(expect.arrayContaining(['threads', `thread:${threadId}`, `run:${runId}`])); const createdAt = new Date(); - graphGateway.emitMessageCreated(threadId, { + graphGateway.emitMessageCreated(threadId, ownerUserId, { id: 'msg-1', kind: 'assistant' as MessageKind, text: 'hello world', @@ -457,11 +488,15 @@ describe('Socket gateway real server handshakes', () => { runId, }); - graphGateway.emitRunStatusChanged(threadId, { - id: runId, - status: 'running' as RunStatus, - createdAt, - updatedAt: createdAt, + graphGateway.emitRunStatusChanged({ + threadId, + ownerUserId, + run: { + id: runId, + status: 'running' as RunStatus, + createdAt, + updatedAt: createdAt, + }, }); const runEventId = 'evt-1'; @@ -620,6 +655,7 @@ describe('Socket gateway real server handshakes', () => { const threadId = 'thread-999'; const runId = 'run-999'; + prismaStub.setThreadOwner(threadId, 'test-user'); const subscribeAck = await new Promise<{ ok: boolean; rooms?: string[]; error?: string }>((resolve, reject) => { const timer = setTimeout(() => reject(new Error('Timed out waiting for subscribe ack')), 3000); diff --git a/packages/platform-server/__tests__/agents.fail_fast.test.ts b/packages/platform-server/__tests__/agents.fail_fast.test.ts index 20b3e4929..ea6a64ca4 100644 --- a/packages/platform-server/__tests__/agents.fail_fast.test.ts +++ b/packages/platform-server/__tests__/agents.fail_fast.test.ts @@ -13,6 +13,8 @@ import { LiveGraphRuntime } from '../src/graph-core/liveGraph.manager'; import { TemplateRegistry } from '../src/graph-core/templateRegistry'; import { RemindersService } from '../src/agents/reminders.service'; +const principal = { userId: 'user-1' } as any; + class StubLLMProvisioner extends LLMProvisioner { async init(): Promise {} async getLLM(): Promise<{ call: (messages: unknown) => Promise<{ text: string; output: unknown[] }> }> { @@ -95,6 +97,6 @@ describe('Fail-fast behavior', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - await expect(ctrl.listThreads({} as any)).rejects.toBeTruthy(); + await expect(ctrl.listThreads({} as any, principal)).rejects.toBeTruthy(); }); }); diff --git a/packages/platform-server/__tests__/agents.persistence.ensureThreadByAlias.test.ts b/packages/platform-server/__tests__/agents.persistence.ensureThreadByAlias.test.ts index 7ba91faea..e06f274b1 100644 --- a/packages/platform-server/__tests__/agents.persistence.ensureThreadByAlias.test.ts +++ b/packages/platform-server/__tests__/agents.persistence.ensureThreadByAlias.test.ts @@ -4,6 +4,7 @@ import { createPrismaStub, StubPrismaService } from './helpers/prisma.stub'; import { createRunEventsStub } from './helpers/runEvents.stub'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; import { createEventsBusStub } from './helpers/eventsBus.stub'; +import { createUserServiceStub } from './helpers/userService.stub'; const metricsStub = { getThreadsMetrics: async () => ({}) } as any; const templateRegistryStub = { toSchema: async () => [], getMeta: () => undefined } as any; @@ -39,6 +40,7 @@ const createService = (stub: any) => { createRunEventsStub() as any, createLinkingStub(), eventsBusStub, + createUserServiceStub(), ); (svc as any).__eventsBusStub = eventsBusStub; return svc; diff --git a/packages/platform-server/__tests__/agents.persistence.extractKindText.test.ts b/packages/platform-server/__tests__/agents.persistence.extractKindText.test.ts index a64ae26cc..3616f13c8 100644 --- a/packages/platform-server/__tests__/agents.persistence.extractKindText.test.ts +++ b/packages/platform-server/__tests__/agents.persistence.extractKindText.test.ts @@ -61,6 +61,7 @@ import type { ResponseFunctionToolCall } from 'openai/resources/responses/respon import { createRunEventsStub } from './helpers/runEvents.stub'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; import { createEventsBusStub } from './helpers/eventsBus.stub'; +import { createUserServiceStub } from './helpers/userService.stub'; const templateRegistryStub = { toSchema: async () => [], getMeta: () => undefined } as any; const graphRepoStub = { @@ -97,6 +98,7 @@ function makeService(): InstanceType { createRunEventsStub() as any, createLinkingStub(), eventsBusStub, + createUserServiceStub(), ); return svc; } @@ -155,6 +157,7 @@ describe('AgentsPersistenceService beginRun/completeRun populates Message.text', createRunEventsStub() as any, linking, eventsBusStub, + createUserServiceStub(), ); // Begin run with user + system messages @@ -179,6 +182,9 @@ describe('AgentsPersistenceService beginRun/completeRun populates Message.text', const createdRunMessages: any[] = []; const runs: any[] = [{ id: 'run-1', threadId: 'thread-1', status: 'running' }]; const prismaMock = { + thread: { + findUnique: async ({ where }: any) => ({ id: where.id, ownerUserId: 'user-test' }), + }, run: { findUnique: async ({ where }: any) => runs.find((x) => x.id === where.id) ?? null, }, @@ -210,6 +216,7 @@ describe('AgentsPersistenceService beginRun/completeRun populates Message.text', runEventsStub as any, linking, eventsBusStub, + createUserServiceStub(), ); const result = await svc.recordTransportAssistantMessage({ @@ -231,10 +238,13 @@ describe('AgentsPersistenceService beginRun/completeRun populates Message.text', role: 'assistant', }), ); - expect(eventsBusStub.emitMessageCreated).toHaveBeenCalledWith({ - threadId: 'thread-1', - message: expect.objectContaining({ id: 'm1', kind: 'assistant', text: 'final reply', runId: 'run-1' }), - }); + expect(eventsBusStub.emitMessageCreated).toHaveBeenCalledWith( + expect.objectContaining({ + threadId: 'thread-1', + ownerUserId: 'user-test', + message: expect.objectContaining({ id: 'm1', kind: 'assistant', text: 'final reply', runId: 'run-1' }), + }), + ); }); it('recordTransportAssistantMessage skips invocation event for send_message source', async () => { @@ -242,6 +252,9 @@ describe('AgentsPersistenceService beginRun/completeRun populates Message.text', const createdRunMessages: any[] = []; const runs: any[] = [{ id: 'run-1', threadId: 'thread-1', status: 'running' }]; const prismaMock = { + thread: { + findUnique: async ({ where }: any) => ({ id: where.id, ownerUserId: 'user-test' }), + }, run: { findUnique: async ({ where }: any) => runs.find((x) => x.id === where.id) ?? null, }, @@ -273,6 +286,7 @@ describe('AgentsPersistenceService beginRun/completeRun populates Message.text', runEventsStub as any, linking, eventsBusStub, + createUserServiceStub(), ); const result = await svc.recordTransportAssistantMessage({ @@ -285,9 +299,12 @@ describe('AgentsPersistenceService beginRun/completeRun populates Message.text', expect(result).toEqual({ messageId: 'm1' }); expect(createdRunMessages).toEqual([{ runId: 'run-1', messageId: 'm1', type: 'output' }]); expect(runEventsStub.recordInvocationMessage).not.toHaveBeenCalled(); - expect(eventsBusStub.emitMessageCreated).toHaveBeenCalledWith({ - threadId: 'thread-1', - message: expect.objectContaining({ id: 'm1', text: 'fallback reply', runId: 'run-1' }), - }); + expect(eventsBusStub.emitMessageCreated).toHaveBeenCalledWith( + expect.objectContaining({ + threadId: 'thread-1', + ownerUserId: 'user-test', + message: expect.objectContaining({ id: 'm1', text: 'fallback reply', runId: 'run-1' }), + }), + ); }); }); diff --git a/packages/platform-server/__tests__/agents.persistence.metrics_titles.test.ts b/packages/platform-server/__tests__/agents.persistence.metrics_titles.test.ts index a0737505d..60970999a 100644 --- a/packages/platform-server/__tests__/agents.persistence.metrics_titles.test.ts +++ b/packages/platform-server/__tests__/agents.persistence.metrics_titles.test.ts @@ -4,6 +4,7 @@ import { StubPrismaService, createPrismaStub } from './helpers/prisma.stub'; import { createRunEventsStub } from './helpers/runEvents.stub'; import { createEventsBusStub } from './helpers/eventsBus.stub'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; +import { createUserServiceStub } from './helpers/userService.stub'; const createLinkingStub = (overrides?: Partial) => ({ @@ -53,6 +54,7 @@ function createService( createRunEventsStub() as any, overrides?.linking ?? createLinkingStub(), eventsBusStub, + createUserServiceStub(), ); return svc; } diff --git a/packages/platform-server/__tests__/agents.reminders.controller.test.ts b/packages/platform-server/__tests__/agents.reminders.controller.test.ts index eae33fa6b..9cb590a91 100644 --- a/packages/platform-server/__tests__/agents.reminders.controller.test.ts +++ b/packages/platform-server/__tests__/agents.reminders.controller.test.ts @@ -9,12 +9,15 @@ import { RemindersService } from '../src/agents/reminders.service'; import { createRunEventsStub } from './helpers/runEvents.stub'; import { createEventsBusStub } from './helpers/eventsBus.stub'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; +import { createUserServiceStub } from './helpers/userService.stub'; const templateRegistryStub = { toSchema: async () => [], getMeta: () => undefined } as any; const graphRepoStub = { get: async () => ({ name: 'main', version: 1, updatedAt: new Date().toISOString(), nodes: [], edges: [] }), } as any; +const principal = { userId: 'user-1' } as any; + const createLinkingStub = () => ({ buildInitialMetadata: (params: { tool: 'call_agent' | 'call_engineer'; parentThreadId: string; childThreadId: string }) => ({ @@ -53,6 +56,7 @@ function createPersistenceWithTx(tx: { reminder: { findMany: any; count: any }; createRunEventsStub() as any, createLinkingStub(), createEventsBusStub(), + createUserServiceStub(), ); } @@ -71,9 +75,9 @@ describe('AgentsRemindersController', () => { }).compile(); const ctrl = await module.resolve(AgentsRemindersController); - const res = await ctrl.listReminders({}); + const res = await ctrl.listReminders({}, principal); - expect(svc.listReminders).toHaveBeenCalledWith('active', 100, undefined); + expect(svc.listReminders).toHaveBeenCalledWith('active', 100, undefined, principal.userId); expect(svc.listRemindersPaginated).not.toHaveBeenCalled(); expect(res).toEqual({ items: [{ id: '1' }] }); }); @@ -91,6 +95,7 @@ describe('AgentsRemindersController', () => { const svc = { listReminders: vi.fn(), listRemindersPaginated: vi.fn(async () => paginatedResponse), + getThreadById: vi.fn(async () => ({ id: 'aaaa1111-1111-1111-1111-111111111111', ownerUserId: principal.userId })), } as unknown as AgentsPersistenceService; const module = await Test.createTestingModule({ controllers: [AgentsRemindersController], @@ -101,7 +106,7 @@ describe('AgentsRemindersController', () => { }).compile(); const ctrl = await module.resolve(AgentsRemindersController); - const result = await ctrl.listReminders({ page: 2, threadId: 'aaaa1111-1111-1111-1111-111111111111' }); + const result = await ctrl.listReminders({ page: 2, threadId: 'aaaa1111-1111-1111-1111-111111111111' }, principal); expect(svc.listReminders).not.toHaveBeenCalled(); expect(svc.listRemindersPaginated).toHaveBeenCalledWith({ @@ -111,6 +116,7 @@ describe('AgentsRemindersController', () => { sort: 'latest', order: 'desc', threadId: 'aaaa1111-1111-1111-1111-111111111111', + ownerUserId: principal.userId, }); expect(result).toEqual(paginatedResponse); }); @@ -132,9 +138,9 @@ describe('AgentsRemindersController', () => { }).compile(); const ctrl = await module.resolve(AgentsRemindersController); - const res = await ctrl.cancelReminder('rem-1'); + const res = await ctrl.cancelReminder('rem-1', principal); - expect(reminders.cancelReminder).toHaveBeenCalledWith({ reminderId: 'rem-1', emitMetrics: true }); + expect(reminders.cancelReminder).toHaveBeenCalledWith({ reminderId: 'rem-1', emitMetrics: true, ownerUserId: principal.userId }); expect(res).toEqual({ ok: true, threadId: 'thread-9' }); }); @@ -156,7 +162,7 @@ describe('AgentsRemindersController', () => { const ctrl = await module.resolve(AgentsRemindersController); - await expect(ctrl.cancelReminder('missing')).rejects.toBeInstanceOf(NotFoundException); + await expect(ctrl.cancelReminder('missing', principal)).rejects.toBeInstanceOf(NotFoundException); }); it('throws 404 when reminders service omits thread id', async () => { @@ -177,7 +183,7 @@ describe('AgentsRemindersController', () => { const ctrl = await module.resolve(AgentsRemindersController); - await expect(ctrl.cancelReminder('rem-2')).rejects.toBeInstanceOf(NotFoundException); + await expect(ctrl.cancelReminder('rem-2', principal)).rejects.toBeInstanceOf(NotFoundException); }); }); @@ -227,6 +233,7 @@ describe('AgentsPersistenceService.listReminders', () => { createRunEventsStub() as any, createLinkingStub(), eventsBusStub, + createUserServiceStub(), ); await svc.listReminders('active', 50); @@ -264,6 +271,7 @@ describe('AgentsPersistenceService.listReminders', () => { createRunEventsStub() as any, createLinkingStub(), eventsBusStub, + createUserServiceStub(), ); const errorSpy = vi.spyOn(Logger.prototype, 'error').mockImplementation(() => {}); diff --git a/packages/platform-server/__tests__/agents.threads.controller.create.test.ts b/packages/platform-server/__tests__/agents.threads.controller.create.test.ts index c92009415..c8d3616ba 100644 --- a/packages/platform-server/__tests__/agents.threads.controller.create.test.ts +++ b/packages/platform-server/__tests__/agents.threads.controller.create.test.ts @@ -15,10 +15,13 @@ const runEventsStub = { getToolOutputSnapshot: async () => null, }; +const principal = { userId: 'user-1' } as any; + type SetupOptions = { nodes?: Array<{ id: string; template: string; instance: { status: string; invoke: ReturnType } }>; templateMeta?: Record; createThreadWithInitialMessage?: ReturnType; + getThreadById?: ReturnType; }; async function setup(options: SetupOptions = {}) { @@ -45,6 +48,8 @@ async function setup(options: SetupOptions = {}) { assignedAgentNodeId: 'agent-1', })); + const getThreadById = options.getThreadById ?? vi.fn(async () => null); + const templateRegistryStub = { getMeta: (template: string) => options.templateMeta?.[template] ?? { kind: 'agent', title: template }, } satisfies Pick; @@ -62,7 +67,7 @@ async function setup(options: SetupOptions = {}) { getThreadsMetrics: async () => ({}), getThreadsAgentTitles: async () => ({}), updateThread: async () => ({ previousStatus: 'open', status: 'open' }), - getThreadById: async () => null, + getThreadById, getLatestAgentNodeIdForThread: async () => null, getRunById: async () => null, ensureAssignedAgent: async () => {}, @@ -90,7 +95,7 @@ describe('AgentsThreadsController POST /api/agents/threads', () => { it('returns bad_message_payload when text is missing', async () => { const { controller, createThreadWithInitialMessage } = await setup(); - await expect(controller.createThread({ agentNodeId: 'agent-1' } as any)).rejects.toMatchObject({ + await expect(controller.createThread({ agentNodeId: 'agent-1' } as any, principal)).rejects.toMatchObject({ status: 400, response: { error: 'bad_message_payload' }, }); @@ -101,7 +106,7 @@ describe('AgentsThreadsController POST /api/agents/threads', () => { it('returns bad_message_payload when agentNodeId is missing', async () => { const { controller, createThreadWithInitialMessage } = await setup(); - await expect(controller.createThread({ text: 'hello there' } as any)).rejects.toMatchObject({ + await expect(controller.createThread({ text: 'hello there' } as any, principal)).rejects.toMatchObject({ status: 400, response: { error: 'bad_message_payload' }, }); @@ -113,7 +118,7 @@ describe('AgentsThreadsController POST /api/agents/threads', () => { const { controller, createThreadWithInitialMessage } = await setup(); await expect( - controller.createThread({ text: 'a'.repeat(100001), agentNodeId: 'agent-1' } as any), + controller.createThread({ text: 'a'.repeat(100001), agentNodeId: 'agent-1' } as any, principal), ).rejects.toMatchObject({ status: 400, response: { error: 'bad_message_payload' }, @@ -126,10 +131,26 @@ describe('AgentsThreadsController POST /api/agents/threads', () => { const createThreadWithInitialMessage = vi.fn(async () => { throw new ThreadParentNotFoundError(); }); - const { controller } = await setup({ createThreadWithInitialMessage }); + const getThreadById = vi.fn(async () => ({ id: 'parent', ownerUserId: principal.userId })); + const { controller } = await setup({ createThreadWithInitialMessage, getThreadById }); + + await expect( + controller.createThread({ text: 'hello', agentNodeId: 'agent-1', parentId: 'missing-parent' } as any, principal), + ).rejects.toMatchObject({ + status: 404, + response: { error: 'parent_not_found' }, + }); + }); + + it('maps thread_parent_owner_mismatch errors to parent_not_found', async () => { + const createThreadWithInitialMessage = vi.fn(async () => { + throw new Error('thread_parent_owner_mismatch'); + }); + const getThreadById = vi.fn(async () => ({ id: 'parent', ownerUserId: principal.userId })); + const { controller } = await setup({ createThreadWithInitialMessage, getThreadById }); await expect( - controller.createThread({ text: 'hello', agentNodeId: 'agent-1', parentId: 'missing-parent' } as any), + controller.createThread({ text: 'hello', agentNodeId: 'agent-1', parentId: 'parent-1' } as any, principal), ).rejects.toMatchObject({ status: 404, response: { error: 'parent_not_found' }, diff --git a/packages/platform-server/__tests__/agents.threads.controller.list.test.ts b/packages/platform-server/__tests__/agents.threads.controller.list.test.ts index d403d075b..c11f7216e 100644 --- a/packages/platform-server/__tests__/agents.threads.controller.list.test.ts +++ b/packages/platform-server/__tests__/agents.threads.controller.list.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect, vi } from 'vitest'; import { Test } from '@nestjs/testing'; +import { NotFoundException } from '@nestjs/common'; import { AgentsThreadsController } from '../src/agents/threads.controller'; import { AgentsPersistenceService } from '../src/agents/agents.persistence.service'; import { RunSignalsRegistry } from '../src/agents/run-signals.service'; @@ -14,6 +15,8 @@ const runEventsStub = { listRunEvents: vi.fn(async () => []), } as unknown as RunEventsService; +const principal = { userId: 'user-1' } as any; + describe('AgentsThreadsController list endpoints', () => { it('requests metrics and agent titles when flags are enabled', async () => { const now = new Date(); @@ -43,9 +46,9 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const res = await ctrl.listThreads({ includeMetrics: 'true', includeAgentTitles: 'true' } as any); + const res = await ctrl.listThreads({ includeMetrics: 'true', includeAgentTitles: 'true' } as any, principal); - expect((persistence.listThreads as any).mock.calls.length).toBe(1); + expect(persistence.listThreads).toHaveBeenCalledWith({ rootsOnly: false, status: 'all', limit: 100, ownerUserId: principal.userId }); expect((persistence.getThreadsMetrics as any).mock.calls[0][0]).toEqual(['t1']); expect((persistence.getThreadsAgentDescriptors as any).mock.calls[0][0]).toEqual(['t1']); expect(res).toMatchObject({ @@ -94,7 +97,7 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const res = await ctrl.listThreads({} as any); + const res = await ctrl.listThreads({} as any, principal); expect((persistence.getThreadsMetrics as any).mock?.calls?.length ?? 0).toBe(0); expect(res.items[0]).toMatchObject({ agentRole: 'Support', agentName: 'Beta' }); @@ -114,6 +117,7 @@ describe('AgentsThreadsController list endpoints', () => { listRuns: vi.fn(), listRunMessages: vi.fn(), updateThread: vi.fn(), + getThreadById: vi.fn(async () => ({ id: 't1', ownerUserId: principal.userId })), } as unknown as AgentsPersistenceService; const module = await Test.createTestingModule({ @@ -130,7 +134,7 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const res = await ctrl.listChildren('t1', { includeMetrics: 'true', includeAgentTitles: 'true' } as any); + const res = await ctrl.listChildren('t1', { includeMetrics: 'true', includeAgentTitles: 'true' } as any, principal); expect(res.items[0].metrics).toEqual({ remindersCount: 0, containersCount: 0, activity: 'idle', runsCount: 0 }); expect(res.items[0].agentTitle).toBe('(unknown agent)'); @@ -150,6 +154,7 @@ describe('AgentsThreadsController list endpoints', () => { listRuns: vi.fn(), listRunMessages: vi.fn(), updateThread: vi.fn(), + getThreadById: vi.fn(async () => ({ id: 't1', ownerUserId: principal.userId })), } as unknown as AgentsPersistenceService; const module = await Test.createTestingModule({ @@ -166,7 +171,7 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const res = await ctrl.listChildren('t1', {} as any); + const res = await ctrl.listChildren('t1', {} as any, principal); expect(res.items[0]).toMatchObject({ agentName: 'Child', agentRole: 'Helper' }); expect(res.items[0].createdAt).toBeInstanceOf(Date); @@ -181,7 +186,7 @@ describe('AgentsThreadsController list endpoints', () => { listRunMessages: vi.fn(), updateThread: vi.fn(), getThreadsAgentDescriptors: vi.fn(), - getThreadById: vi.fn(), + getThreadById: vi.fn(async () => ({ id: 't-miss', ownerUserId: principal.userId })), } as unknown as AgentsPersistenceService; const module = await Test.createTestingModule({ @@ -198,15 +203,15 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const res = await ctrl.getThreadMetrics('t-miss'); + const res = await ctrl.getThreadMetrics('t-miss', principal); expect(res).toEqual({ remindersCount: 0, containersCount: 0, activity: 'idle', runsCount: 0 }); }); it('getThread returns defaults for metrics and titles when missing', async () => { const now = new Date(); const persistence = { - getThreadById: vi.fn(async (_id: string, opts: { includeMetrics?: boolean; includeAgentTitles?: boolean }) => { - expect(opts).toEqual({ includeMetrics: true, includeAgentTitles: true }); + getThreadById: vi.fn(async (_id: string, opts: { includeMetrics?: boolean; includeAgentTitles?: boolean; ownerUserId?: string }) => { + expect(opts).toEqual({ includeMetrics: true, includeAgentTitles: true, ownerUserId: principal.userId }); return { id: 't1', alias: 'alias', @@ -241,7 +246,7 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const result = await ctrl.getThread('t1', { includeMetrics: 'true', includeAgentTitles: 'true' } as any); + const result = await ctrl.getThread('t1', { includeMetrics: 'true', includeAgentTitles: 'true' } as any, principal); expect(result).toMatchObject({ id: 't1', @@ -258,7 +263,7 @@ describe('AgentsThreadsController list endpoints', () => { it('getThread forwards agent name and role without optional flags', async () => { const now = new Date(); const persistence = { - getThreadById: vi.fn(async () => ({ + getThreadById: vi.fn(async (_id: string, opts: { ownerUserId?: string }) => ({ id: 't2', alias: 'alias', summary: 'Summary', @@ -267,6 +272,7 @@ describe('AgentsThreadsController list endpoints', () => { parentId: null, agentName: 'Agent X', agentRole: 'Planner', + ownerUserId: opts.ownerUserId, })), listThreads: vi.fn(), listChildren: vi.fn(), @@ -291,7 +297,7 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const result = await ctrl.getThread('t2', {} as any); + const result = await ctrl.getThread('t2', {} as any, principal); expect(result).toMatchObject({ agentName: 'Agent X', agentRole: 'Planner' }); }); @@ -322,7 +328,7 @@ describe('AgentsThreadsController list endpoints', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - await expect(ctrl.getThread('missing', {} as any)).rejects.toThrowError('thread_not_found'); + await expect(ctrl.getThread('missing', {} as any, principal)).rejects.toThrow(NotFoundException); }); }); diff --git a/packages/platform-server/__tests__/agents.threads.controller.patch.test.ts b/packages/platform-server/__tests__/agents.threads.controller.patch.test.ts index 07d821371..06c476e2e 100644 --- a/packages/platform-server/__tests__/agents.threads.controller.patch.test.ts +++ b/packages/platform-server/__tests__/agents.threads.controller.patch.test.ts @@ -9,6 +9,8 @@ import { LiveGraphRuntime } from '../src/graph-core/liveGraph.manager'; import { TemplateRegistry } from '../src/graph-core/templateRegistry'; import { RemindersService } from '../src/agents/reminders.service'; +const principal = { userId: 'user-1' } as any; + const runEventsStub = { getRunSummary: async () => ({ status: 'unknown', @@ -61,9 +63,9 @@ describe('AgentsThreadsController PATCH threads/:id', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - await ctrl.patchThread('t1', { summary: null }); + await ctrl.patchThread('t1', { summary: null }, principal); expect(closeCascade).not.toHaveBeenCalled(); - await ctrl.patchThread('t2', { status: 'closed' }); + await ctrl.patchThread('t2', { status: 'closed' }, principal); expect(updates).toEqual([ { id: 't1', data: { summary: null } }, @@ -99,9 +101,9 @@ describe('AgentsThreadsController PATCH threads/:id', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - await ctrl.patchThread('closed-thread', { status: 'closed' }); + await ctrl.patchThread('closed-thread', { status: 'closed' }, principal); - expect(updateThread).toHaveBeenCalledWith('closed-thread', { status: 'closed' }); + expect(updateThread).toHaveBeenCalledWith('closed-thread', { status: 'closed' }, { ownerUserId: principal.userId }); expect(closeCascade).toHaveBeenCalledWith('closed-thread'); }); @@ -131,7 +133,7 @@ describe('AgentsThreadsController PATCH threads/:id', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - await ctrl.patchThread('already-closed', { status: 'closed' }); + await ctrl.patchThread('already-closed', { status: 'closed' }, principal); expect(closeCascade).not.toHaveBeenCalled(); }); diff --git a/packages/platform-server/__tests__/agents.threads.controller.queued-messages.test.ts b/packages/platform-server/__tests__/agents.threads.controller.queued-messages.test.ts index 1f1a83454..e4e451ce4 100644 --- a/packages/platform-server/__tests__/agents.threads.controller.queued-messages.test.ts +++ b/packages/platform-server/__tests__/agents.threads.controller.queued-messages.test.ts @@ -10,6 +10,8 @@ import { TemplateRegistry } from '../src/graph-core/templateRegistry'; import { InternalServerErrorException, NotFoundException } from '@nestjs/common'; import { RemindersService } from '../src/agents/reminders.service'; +const principal = { userId: 'user-1' } as any; + const runEventsStub = { getRunSummary: async () => null, listRunEvents: async () => ({ items: [], nextCursor: null }), @@ -113,7 +115,7 @@ describe('AgentsThreadsController GET /api/agents/threads/:threadId/queued-messa nodes: [{ id: 'agent-1', template: 'agent', instance: { status: 'ready', invoke: vi.fn(), listQueuedPreview } }], }); - const result = await controller.listQueuedMessages('thread-1'); + const result = await controller.listQueuedMessages('thread-1', principal); expect(result.items).toEqual([ { id: 'msg-1', text: 'hello', enqueuedAt: new Date(1700000000000).toISOString() }, @@ -127,7 +129,7 @@ describe('AgentsThreadsController GET /api/agents/threads/:threadId/queued-messa nodes: [], }); - const result = await controller.listQueuedMessages('thread-1'); + const result = await controller.listQueuedMessages('thread-1', principal); expect(result).toEqual({ items: [] }); }); @@ -135,7 +137,7 @@ describe('AgentsThreadsController GET /api/agents/threads/:threadId/queued-messa it('throws when thread does not exist', async () => { const { controller } = await setup({ thread: null }); - await expect(controller.listQueuedMessages('missing-thread')).rejects.toBeInstanceOf(NotFoundException); + await expect(controller.listQueuedMessages('missing-thread', principal)).rejects.toBeInstanceOf(NotFoundException); }); it('normalizes missing text fields to empty string', async () => { @@ -145,7 +147,7 @@ describe('AgentsThreadsController GET /api/agents/threads/:threadId/queued-messa nodes: [{ id: 'agent-1', template: 'agent', instance: { status: 'ready', invoke: vi.fn(), listQueuedPreview } }], }); - const result = await controller.listQueuedMessages('thread-1'); + const result = await controller.listQueuedMessages('thread-1', principal); expect(result.items).toEqual([ { id: 'msg-2', text: '', enqueuedAt: new Date(1700000001000).toISOString() }, @@ -172,7 +174,7 @@ describe('AgentsThreadsController DELETE /api/agents/threads/:threadId/queued-me ], }); - const result = await controller.clearQueuedMessages('thread-1'); + const result = await controller.clearQueuedMessages('thread-1', principal); expect(result).toEqual({ clearedCount: 5 }); expect(clearQueuedMessages).toHaveBeenCalledWith('thread-1'); @@ -190,7 +192,7 @@ describe('AgentsThreadsController DELETE /api/agents/threads/:threadId/queued-me ], }); - const result = await controller.clearQueuedMessages('thread-1'); + const result = await controller.clearQueuedMessages('thread-1', principal); expect(result).toEqual({ clearedCount: 0 }); }); @@ -215,13 +217,13 @@ describe('AgentsThreadsController DELETE /api/agents/threads/:threadId/queued-me ], }); - await expect(controller.clearQueuedMessages('thread-1')).rejects.toBeInstanceOf(InternalServerErrorException); + await expect(controller.clearQueuedMessages('thread-1', principal)).rejects.toBeInstanceOf(InternalServerErrorException); }); it('throws when thread is missing', async () => { const { controller } = await setup({ thread: null }); - await expect(controller.clearQueuedMessages('thread-1')).rejects.toBeInstanceOf(NotFoundException); + await expect(controller.clearQueuedMessages('thread-1', principal)).rejects.toBeInstanceOf(NotFoundException); }); }); @@ -230,7 +232,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/reminders/c const cancelThreadReminders = vi.fn(async () => ({ cancelledDb: 2, clearedRuntime: 1 })); const { controller } = await setup({ remindersService: { cancelThreadReminders } }); - const result = await controller.cancelThreadReminders('thread-1'); + const result = await controller.cancelThreadReminders('thread-1', principal); expect(cancelThreadReminders).toHaveBeenCalledWith({ threadId: 'thread-1', emitMetrics: true }); expect(result).toEqual({ cancelledDb: 2, clearedRuntime: 1 }); @@ -240,7 +242,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/reminders/c const cancelThreadReminders = vi.fn(); const { controller } = await setup({ thread: null, remindersService: { cancelThreadReminders } }); - await expect(controller.cancelThreadReminders('missing-thread')).rejects.toBeInstanceOf(NotFoundException); + await expect(controller.cancelThreadReminders('missing-thread', principal)).rejects.toBeInstanceOf(NotFoundException); expect(cancelThreadReminders).not.toHaveBeenCalled(); }); @@ -250,7 +252,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/reminders/c }); const { controller } = await setup({ remindersService: { cancelThreadReminders } }); - await expect(controller.cancelThreadReminders('thread-1')).rejects.toBeInstanceOf(InternalServerErrorException); + await expect(controller.cancelThreadReminders('thread-1', principal)).rejects.toBeInstanceOf(InternalServerErrorException); }); }); @@ -259,7 +261,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/reminders/c const { controller, reminders } = await setup(); const spy = vi.spyOn(reminders, 'cancelThreadReminders').mockResolvedValue({ cancelledDb: 2, clearedRuntime: 1 }); - const result = await controller.cancelThreadReminders('thread-1'); + const result = await controller.cancelThreadReminders('thread-1', principal); expect(result).toEqual({ cancelledDb: 2, clearedRuntime: 1 }); expect(spy).toHaveBeenCalledWith({ threadId: 'thread-1', emitMetrics: true }); @@ -268,14 +270,14 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/reminders/c it('throws 404 when thread missing', async () => { const { controller } = await setup({ thread: null }); - await expect(controller.cancelThreadReminders('thread-1')).rejects.toBeInstanceOf(NotFoundException); + await expect(controller.cancelThreadReminders('thread-1', principal)).rejects.toBeInstanceOf(NotFoundException); }); it('bubbles errors from service as 500', async () => { const { controller, reminders } = await setup(); vi.spyOn(reminders, 'cancelThreadReminders').mockRejectedValue(new Error('fail')); - await expect(controller.cancelThreadReminders('thread-1')).rejects.toBeInstanceOf( + await expect(controller.cancelThreadReminders('thread-1', principal)).rejects.toBeInstanceOf( InternalServerErrorException, ); }); diff --git a/packages/platform-server/__tests__/agents.threads.controller.send-message.test.ts b/packages/platform-server/__tests__/agents.threads.controller.send-message.test.ts index 5b9959eee..dc7dcb3ba 100644 --- a/packages/platform-server/__tests__/agents.threads.controller.send-message.test.ts +++ b/packages/platform-server/__tests__/agents.threads.controller.send-message.test.ts @@ -11,6 +11,8 @@ import type { ThreadStatus } from '@prisma/client'; import { NotFoundException, ServiceUnavailableException } from '@nestjs/common'; import { RemindersService } from '../src/agents/reminders.service'; +const principal = { userId: 'user-1' } as any; + const runEventsStub = { getRunSummary: async () => null, listRunEvents: async () => ({ items: [], nextCursor: null }), @@ -84,7 +86,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', it('dispatches message to agent runtime when thread is open', async () => { const { controller, invoke, getLatestAgentNodeIdForThread, ensureAssignedAgent } = await setup(); - const result = await controller.sendThreadMessage('thread-1', { text: ' hello world ' }); + const result = await controller.sendThreadMessage('thread-1', { text: ' hello world ' }, principal); expect(result).toEqual({ ok: true }); expect(getLatestAgentNodeIdForThread).not.toHaveBeenCalled(); @@ -98,7 +100,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', it('rejects when message body is invalid', async () => { const { controller } = await setup(); - await expect(controller.sendThreadMessage('thread-1', { text: ' ' })).rejects.toMatchObject({ + await expect(controller.sendThreadMessage('thread-1', { text: ' ' }, principal)).rejects.toMatchObject({ status: 400, response: { error: 'bad_message_payload' }, }); @@ -108,7 +110,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', const { controller } = await setup(); const overLimit = 'a'.repeat(100001); - await expect(controller.sendThreadMessage('thread-1', { text: overLimit })).rejects.toMatchObject({ + await expect(controller.sendThreadMessage('thread-1', { text: overLimit }, principal)).rejects.toMatchObject({ status: 400, response: { error: 'bad_message_payload' }, }); @@ -118,7 +120,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', const { controller } = await setup({ thread: null, latestAgentNodeId: null }); expect.assertions(2); try { - await controller.sendThreadMessage('missing-thread', { text: 'hello' }); + await controller.sendThreadMessage('missing-thread', { text: 'hello' }, principal); throw new Error('expected NotFoundException'); } catch (error) { expect(error).toBeInstanceOf(NotFoundException); @@ -128,7 +130,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', it('rejects when thread is closed', async () => { const { controller } = await setup({ thread: { id: 'thread-1', status: 'closed' as ThreadStatus } }); - await expect(controller.sendThreadMessage('thread-1', { text: 'hello' })).rejects.toMatchObject({ + await expect(controller.sendThreadMessage('thread-1', { text: 'hello' }, principal)).rejects.toMatchObject({ status: 409, response: { error: 'thread_closed' }, }); @@ -141,7 +143,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', }); expect.assertions(2); try { - await controller.sendThreadMessage('thread-1', { text: 'hello' }); + await controller.sendThreadMessage('thread-1', { text: 'hello' }, principal); throw new Error('expected ServiceUnavailableException'); } catch (error) { expect(error).toBeInstanceOf(ServiceUnavailableException); @@ -151,7 +153,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', it('rejects when thread is missing an assigned agent', async () => { const { controller } = await setup({ thread: { id: 'thread-1', status: 'open' as ThreadStatus, assignedAgentNodeId: null } }); - await expect(controller.sendThreadMessage('thread-1', { text: 'hello' })).rejects.toMatchObject({ + await expect(controller.sendThreadMessage('thread-1', { text: 'hello' }, principal)).rejects.toMatchObject({ status: 503, response: { error: 'agent_unavailable' }, }); @@ -162,7 +164,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', const { controller } = await setup({ nodes: [{ id: 'agent-1', template: 'agent', instance: { status: 'not_ready', invoke } }], }); - await expect(controller.sendThreadMessage('thread-1', { text: 'hello' })).rejects.toMatchObject({ + await expect(controller.sendThreadMessage('thread-1', { text: 'hello' }, principal)).rejects.toMatchObject({ status: 503, response: { error: 'agent_unready' }, }); @@ -176,7 +178,7 @@ describe('AgentsThreadsController POST /api/agents/threads/:threadId/messages', templateMeta: { 'custom.agent': { kind: 'agent', title: 'Custom Agent' } }, }); - await controller.sendThreadMessage('thread-1', { text: 'hello meta agent' }); + await controller.sendThreadMessage('thread-1', { text: 'hello meta agent' }, principal); expect(getLatestAgentNodeIdForThread).not.toHaveBeenCalled(); expect(ensureAssignedAgent).not.toHaveBeenCalled(); diff --git a/packages/platform-server/__tests__/agents.threads.controller.terminate.run.test.ts b/packages/platform-server/__tests__/agents.threads.controller.terminate.run.test.ts index f39fb2fd8..e4e215b98 100644 --- a/packages/platform-server/__tests__/agents.threads.controller.terminate.run.test.ts +++ b/packages/platform-server/__tests__/agents.threads.controller.terminate.run.test.ts @@ -10,6 +10,8 @@ import { LiveGraphRuntime } from '../src/graph-core/liveGraph.manager'; import { TemplateRegistry } from '../src/graph-core/templateRegistry'; import { RemindersService } from '../src/agents/reminders.service'; +const principal = { userId: 'user-1' } as any; + const runEventsStub = { getRunSummary: vi.fn(), listRunEvents: vi.fn(), @@ -38,7 +40,7 @@ describe('AgentsThreadsController terminate run endpoint', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const res = await ctrl.terminateRun('run-1'); + const res = await ctrl.terminateRun('run-1', principal); expect(res).toEqual({ ok: true }); expect(activateTerminate).toHaveBeenCalledWith('run-1'); }); @@ -63,7 +65,7 @@ describe('AgentsThreadsController terminate run endpoint', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - const res = await ctrl.terminateRun('run-2'); + const res = await ctrl.terminateRun('run-2', principal); expect(res).toEqual({ ok: true }); expect(activateTerminate).not.toHaveBeenCalled(); }); @@ -87,6 +89,6 @@ describe('AgentsThreadsController terminate run endpoint', () => { }).compile(); const ctrl = await module.resolve(AgentsThreadsController); - await expect(ctrl.terminateRun('missing')).rejects.toBeInstanceOf(NotFoundException); + await expect(ctrl.terminateRun('missing', principal)).rejects.toBeInstanceOf(NotFoundException); }); }); diff --git a/packages/platform-server/__tests__/agents.threads.controller.tool-output.test.ts b/packages/platform-server/__tests__/agents.threads.controller.tool-output.test.ts index 7df1d22b9..f11ee39f4 100644 --- a/packages/platform-server/__tests__/agents.threads.controller.tool-output.test.ts +++ b/packages/platform-server/__tests__/agents.threads.controller.tool-output.test.ts @@ -10,6 +10,8 @@ import { LiveGraphRuntime } from '../src/graph-core/liveGraph.manager'; import { TemplateRegistry } from '../src/graph-core/templateRegistry'; import { RemindersService } from '../src/agents/reminders.service'; +const principal = { userId: 'user-1' } as any; + describe('AgentsThreadsController tool output snapshot endpoint', () => { it('returns 501 when tool output persistence is unavailable', async () => { const runEventsStub = { @@ -19,7 +21,12 @@ describe('AgentsThreadsController tool output snapshot endpoint', () => { const module = await Test.createTestingModule({ controllers: [AgentsThreadsController], providers: [ - { provide: AgentsPersistenceService, useValue: {} }, + { + provide: AgentsPersistenceService, + useValue: { + getRunById: vi.fn(async () => ({ id: 'run-1', threadId: 'thread-1' })), + }, + }, { provide: ThreadCleanupCoordinator, useValue: { closeThreadWithCascade: vi.fn() } }, { provide: RunEventsService, useValue: runEventsStub }, { provide: RunSignalsRegistry, useValue: { register: vi.fn(), activateTerminate: vi.fn(), clear: vi.fn() } }, @@ -32,7 +39,7 @@ describe('AgentsThreadsController tool output snapshot endpoint', () => { const ctrl = await module.resolve(AgentsThreadsController); await expect( - ctrl.getRunEventOutput('run-1', 'event-1', { order: 'asc' } as any), + ctrl.getRunEventOutput('run-1', 'event-1', { order: 'asc' } as any, principal), ).rejects.toThrowError( new NotImplementedException( 'Tool output persistence unavailable. Run `pnpm --filter @agyn/platform-server prisma migrate deploy` followed by `pnpm --filter @agyn/platform-server prisma generate` to install the latest schema.', @@ -56,7 +63,12 @@ describe('AgentsThreadsController tool output snapshot endpoint', () => { const module = await Test.createTestingModule({ controllers: [AgentsThreadsController], providers: [ - { provide: AgentsPersistenceService, useValue: {} }, + { + provide: AgentsPersistenceService, + useValue: { + getRunById: vi.fn(async () => ({ id: 'run-1', threadId: 'thread-1' })), + }, + }, { provide: ThreadCleanupCoordinator, useValue: { closeThreadWithCascade: vi.fn() } }, { provide: RunEventsService, useValue: runEventsStub }, { provide: RunSignalsRegistry, useValue: { register: vi.fn(), activateTerminate: vi.fn(), clear: vi.fn() } }, @@ -68,7 +80,7 @@ describe('AgentsThreadsController tool output snapshot endpoint', () => { const ctrl = await module.resolve(AgentsThreadsController); - const result = await ctrl.getRunEventOutput('run-1', 'event-1', { order: 'asc' } as any); + const result = await ctrl.getRunEventOutput('run-1', 'event-1', { order: 'asc' } as any, principal); expect(result).toBe(snapshot); expect(runEventsStub.getToolOutputSnapshot).toHaveBeenCalledWith({ runId: 'run-1', diff --git a/packages/platform-server/__tests__/agents.threads.filters.test.ts b/packages/platform-server/__tests__/agents.threads.filters.test.ts index e0e38b1ea..d43fdcb47 100644 --- a/packages/platform-server/__tests__/agents.threads.filters.test.ts +++ b/packages/platform-server/__tests__/agents.threads.filters.test.ts @@ -4,6 +4,7 @@ import { StubPrismaService, createPrismaStub } from './helpers/prisma.stub'; import { createRunEventsStub } from './helpers/runEvents.stub'; import { createEventsBusStub } from './helpers/eventsBus.stub'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; +import { createUserServiceStub } from './helpers/userService.stub'; const metricsStub = { getThreadsMetrics: async () => ({}) } as any; const templateRegistryStub = { toSchema: async () => [], getMeta: () => undefined } as any; @@ -42,6 +43,7 @@ describe('AgentsPersistenceService threads filters and updates', () => { createRunEventsStub() as any, createLinkingStub(), eventsBusStub, + createUserServiceStub(), ); // seed const rootOpen = await stub.thread.create({ data: { alias: 'a1', parentId: null, summary: 'A1', status: 'open' } }); diff --git a/packages/platform-server/__tests__/agents.threads.tree.spec.ts b/packages/platform-server/__tests__/agents.threads.tree.spec.ts index 4d0ef9b89..c4e7fb0d2 100644 --- a/packages/platform-server/__tests__/agents.threads.tree.spec.ts +++ b/packages/platform-server/__tests__/agents.threads.tree.spec.ts @@ -4,6 +4,7 @@ import { StubPrismaService, createPrismaStub } from './helpers/prisma.stub'; import { createRunEventsStub } from './helpers/runEvents.stub'; import { createEventsBusStub } from './helpers/eventsBus.stub'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; +import { createUserServiceStub } from './helpers/userService.stub'; const metricsStub = { getThreadsMetrics: async () => ({}) } as any; const templateRegistryStub = { toSchema: async () => [], getMeta: () => undefined } as any; @@ -40,6 +41,7 @@ function createService() { createRunEventsStub() as any, createLinkingStub(), createEventsBusStub(), + createUserServiceStub(), ); return { prismaStub, svc }; } diff --git a/packages/platform-server/__tests__/agents/agents.persistence.service.spec.ts b/packages/platform-server/__tests__/agents/agents.persistence.service.spec.ts index 6ce60265b..6584b681b 100644 --- a/packages/platform-server/__tests__/agents/agents.persistence.service.spec.ts +++ b/packages/platform-server/__tests__/agents/agents.persistence.service.spec.ts @@ -10,6 +10,7 @@ import type { RunEventsService } from '../../src/events/run-events.service'; import type { CallAgentLinkingService } from '../../src/agents/call-agent-linking.service'; import type { EventsBusService } from '../../src/events/events-bus.service'; import { HumanMessage } from '@agyn/llm'; +import { createUserServiceStub } from '../helpers/userService.stub'; describe('AgentsPersistenceService', () => { it('persists invocation messages as user role', async () => { @@ -21,13 +22,19 @@ describe('AgentsPersistenceService', () => { createdAt: new Date('2024-01-01T00:00:00Z'), })); + const threadRepository = { + findUnique: vi.fn().mockResolvedValue({ id: 'thread-1', ownerUserId: 'user-default' }), + }; + const txClient = { + thread: threadRepository, run: { create: vi.fn().mockResolvedValue({ id: 'run-1', threadId: 'thread-1', status: 'running' }) }, message: { create: messageCreate }, runMessage: { create: vi.fn().mockResolvedValue(undefined) }, }; const prismaClient = { + thread: threadRepository, $transaction: vi.fn(async (fn: (tx: typeof txClient) => Promise) => await fn(txClient)), }; @@ -47,6 +54,7 @@ describe('AgentsPersistenceService', () => { emitThreadMetrics: vi.fn(), publishEvent: vi.fn().mockResolvedValue(undefined), } as unknown as EventsBusService; + const userService = createUserServiceStub(); const service = new AgentsPersistenceService( prismaService, @@ -56,6 +64,7 @@ describe('AgentsPersistenceService', () => { runEvents, callAgentLinking, eventsBus, + userService, ); await service.beginRunThread('thread-1', [HumanMessage.fromText('Hello there')]); diff --git a/packages/platform-server/__tests__/call_agent.parentId.integration.test.ts b/packages/platform-server/__tests__/call_agent.parentId.integration.test.ts index d1277324b..64afc54d6 100644 --- a/packages/platform-server/__tests__/call_agent.parentId.integration.test.ts +++ b/packages/platform-server/__tests__/call_agent.parentId.integration.test.ts @@ -7,6 +7,7 @@ import { createPrismaStub, StubPrismaService } from './helpers/prisma.stub'; import { createRunEventsStub } from './helpers/runEvents.stub'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; import { createEventsBusStub } from './helpers/eventsBus.stub'; +import { createUserServiceStub } from './helpers/userService.stub'; const metricsStub = { getThreadsMetrics: async () => ({}) } as any; const templateRegistryStub = { toSchema: async () => [], getMeta: () => undefined } as any; @@ -50,6 +51,7 @@ describe('call_agent integration: creates child thread with parentId', () => { onChildRunCompleted: async () => null, } as unknown as CallAgentLinkingService, eventsBus as any, + createUserServiceStub(), ); const linking = { buildInitialMetadata: (params: { tool: 'call_agent' | 'call_engineer'; parentThreadId: string; childThreadId: string }) => ({ diff --git a/packages/platform-server/__tests__/call_agent.timeline.metadata.integration.test.ts b/packages/platform-server/__tests__/call_agent.timeline.metadata.integration.test.ts index 3a5e86e36..598f83e10 100644 --- a/packages/platform-server/__tests__/call_agent.timeline.metadata.integration.test.ts +++ b/packages/platform-server/__tests__/call_agent.timeline.metadata.integration.test.ts @@ -10,6 +10,7 @@ import type { TemplateRegistry } from '../src/graph-core/templateRegistry'; import type { GraphRepository } from '../src/graph/graph.repository'; import { HumanMessage, SystemMessage, AIMessage } from '@agyn/llm'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; +import { UserService } from '../src/auth/user.service'; const databaseUrl = process.env.AGENTS_DATABASE_URL; const shouldRunDbTests = process.env.RUN_DB_TESTS === 'true' && !!databaseUrl; @@ -31,6 +32,7 @@ if (!shouldRunDbTests) { const runEvents = new RunEventsService(prismaService); const eventsBus = new EventsBusService(runEvents); const callAgentLinking = new CallAgentLinkingService(prismaService, runEvents, eventsBus); + const userService = new UserService(prismaService); const agents = new AgentsPersistenceService( prismaService, metricsStub, @@ -39,6 +41,7 @@ if (!shouldRunDbTests) { runEvents, callAgentLinking, eventsBus, + userService, ); async function createCallAgentParentEvent(parentThreadId: string, childThreadId: string, runId: string) { diff --git a/packages/platform-server/__tests__/call_agent.tool.test.ts b/packages/platform-server/__tests__/call_agent.tool.test.ts index 111975735..21f617162 100644 --- a/packages/platform-server/__tests__/call_agent.tool.test.ts +++ b/packages/platform-server/__tests__/call_agent.tool.test.ts @@ -7,6 +7,7 @@ import { createRunEventsStub } from './helpers/runEvents.stub'; import { Signal } from '../src/signal'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; import { createEventsBusStub } from './helpers/eventsBus.stub'; +import { createUserServiceStub } from './helpers/userService.stub'; const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms)); @@ -48,6 +49,7 @@ const createPersistence = (linking?: CallAgentLinkingService) => { createRunEventsStub() as any, linking ?? createLinkingStub().instance, eventsBusStub as any, + createUserServiceStub(), ); return svc; }; @@ -75,6 +77,7 @@ describe('CallAgentTool unit', () => { it('calls attached agent and returns its response.text', async () => { const { instance: linking, spies } = createLinkingStub(); const persistence = createPersistence(linking); + const parentThreadId = await persistence.getOrCreateThreadByAlias('unit', 'parent-sync', 'Parent thread'); const tool = new CallAgentTool(persistence, linking); await tool.setConfig({ description: 'desc', response: 'sync' }); const agent = new FakeAgent(async (_thread, _msgs) => { @@ -86,7 +89,7 @@ describe('CallAgentTool unit', () => { const dynamic = tool.getTool(); const out = await dynamic.execute( { input: 'ping', threadAlias: 'sub', summary: 'sub summary' }, - { threadId: 't2', runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, + { threadId: parentThreadId, runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, ); expect(out).toBe('OK'); expect(spies.registerParentToolExecution).toHaveBeenCalledTimes(1); @@ -106,6 +109,7 @@ describe('CallAgentTool unit', () => { it('resolves subthread by alias under parent UUID', async () => { const { instance: linking } = createLinkingStub(); const persistence = createPersistence(linking); + const parentThreadId = await persistence.getOrCreateThreadByAlias('unit', 'parent-alias', 'Parent thread'); const tool = new CallAgentTool(persistence, linking); await tool.setConfig({ description: 'desc', response: 'sync' }); const agent = new FakeAgent(async (_thread, _msgs) => { @@ -117,14 +121,16 @@ describe('CallAgentTool unit', () => { const dynamic = tool.getTool(); const out = await dynamic.execute( { input: 'ping', threadAlias: 'sub', summary: 'sub summary' }, - { threadId: 'parent', runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, + { threadId: parentThreadId, runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, ); expect(out).toBe('OK'); }); it('async mode returns sent immediately', async () => { const { instance: linking } = createLinkingStub(); - const tool = new CallAgentTool(createPersistence(linking), linking); + const persistence = createPersistence(linking); + const parentThreadId = await persistence.getOrCreateThreadByAlias('unit', 'parent-async', 'Parent thread'); + const tool = new CallAgentTool(persistence, linking); await tool.setConfig({ description: 'desc', response: 'async' }); const child = new FakeAgent(async (thread, msgs) => { expect(msgs[0]?.text).toBe('do work'); @@ -136,7 +142,7 @@ describe('CallAgentTool unit', () => { const dynamic = tool.getTool(); const res = await dynamic.execute( { input: 'do work', threadAlias: 'c1', summary: 'c1 summary' }, - { threadId: 'p', runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, + { threadId: parentThreadId, runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, ); expect(typeof res).toBe('string'); expect(JSON.parse(res).status).toBe('sent'); @@ -144,7 +150,9 @@ describe('CallAgentTool unit', () => { it('ignore mode returns sent and does not trigger parent', async () => { const { instance: linking } = createLinkingStub(); - const tool = new CallAgentTool(createPersistence(linking), linking); + const persistence = createPersistence(linking); + const parentThreadId = await persistence.getOrCreateThreadByAlias('unit', 'parent-ignore', 'Parent thread'); + const tool = new CallAgentTool(persistence, linking); await tool.setConfig({ description: 'desc', response: 'ignore' }); const child = new FakeAgent(async () => { const ai = AIMessage.fromText('ignored'); @@ -155,7 +163,7 @@ describe('CallAgentTool unit', () => { const dynamic = tool.getTool(); const res = await dynamic.execute( { input: 'do work', threadAlias: 'c2', summary: 'c2 summary' }, - { threadId: 'p2', runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, + { threadId: parentThreadId, runId: 'r', finishSignal: new Signal(), terminateSignal: new Signal(), callerAgent: {} }, ); expect(typeof res).toBe('string'); expect(JSON.parse(res).status).toBe('sent'); diff --git a/packages/platform-server/__tests__/graph.socket.gateway.bus.test.ts b/packages/platform-server/__tests__/graph.socket.gateway.bus.test.ts index 0fcb7e70c..a71b876ac 100644 --- a/packages/platform-server/__tests__/graph.socket.gateway.bus.test.ts +++ b/packages/platform-server/__tests__/graph.socket.gateway.bus.test.ts @@ -2,6 +2,8 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import type { EventsBusService, ReminderCountEvent, RunEventBusPayload } from '../src/events/events-bus.service'; import type { ToolOutputChunkPayload, ToolOutputTerminalPayload } from '../src/events/run-events.service'; import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; type Handler = ((payload: T) => void) | null; @@ -13,10 +15,14 @@ type GatewayTestContext = { terminal: Handler; reminder: Handler; nodeState: Handler<{ nodeId: string; state: Record; updatedAtMs?: number }>; - threadCreated: Handler<{ id: string }>; - threadUpdated: Handler<{ id: string }>; - messageCreated: Handler<{ threadId: string; message: { id: string } }>; - runStatus: Handler<{ threadId: string; run: { id: string; status: string; createdAt: Date; updatedAt: Date } }>; + threadCreated: Handler<{ id: string; ownerUserId: string }>; + threadUpdated: Handler<{ id: string; ownerUserId: string }>; + messageCreated: Handler<{ threadId: string; ownerUserId: string; message: { id: string } }>; + runStatus: Handler<{ + threadId: string; + ownerUserId: string; + run: { id: string; status: string; createdAt: Date; updatedAt: Date }; + }>; threadMetrics: Handler<{ threadId: string }>; threadMetricsAncestors: Handler<{ threadId: string }>; }; @@ -116,7 +122,9 @@ function createGatewayTestContext(): GatewayTestContext { const metrics = { getThreadsMetrics: vi.fn().mockResolvedValue({}) } as any; const prisma = { getClient: vi.fn().mockReturnValue({ $queryRaw: vi.fn().mockResolvedValue([]) }) } as any; - const gateway = new GraphSocketGateway(runtime, metrics, prisma, eventsBus as EventsBusService); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: vi.fn() } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtime, metrics, prisma, eventsBus as EventsBusService, configStub, authStub); const internalLogger = (gateway as unknown as { logger: { warn: (...args: unknown[]) => void; error: (...args: unknown[]) => void; log: (...args: unknown[]) => void; debug: (...args: unknown[]) => void } }).logger; const logger = { warn: vi.spyOn(internalLogger, 'warn').mockImplementation(() => undefined), @@ -262,6 +270,7 @@ describe('GraphSocketGateway event bus integration', () => { createdAt: new Date(), parentId: null, channelNodeId: null, + ownerUserId: 'user-1', } as any); ctx.handlers.threadUpdated?.({ id: 'thread-2', @@ -271,14 +280,27 @@ describe('GraphSocketGateway event bus integration', () => { createdAt: new Date(), parentId: null, channelNodeId: null, + ownerUserId: 'user-1', } as any); - ctx.handlers.messageCreated?.({ threadId: 'thread-1', message: { id: 'msg-1', kind: 'user', text: 'hi', source: {}, createdAt: new Date() } as any }); - ctx.handlers.runStatus?.({ threadId: 'thread-1', run: { id: 'run-1', status: 'running', createdAt: new Date(), updatedAt: new Date() } }); + ctx.handlers.messageCreated?.({ + threadId: 'thread-1', + ownerUserId: 'user-1', + message: { id: 'msg-1', kind: 'user', text: 'hi', source: {}, createdAt: new Date() } as any, + }); + ctx.handlers.runStatus?.({ + threadId: 'thread-1', + ownerUserId: 'user-1', + run: { id: 'run-1', status: 'running', createdAt: new Date(), updatedAt: new Date() }, + }); expect(threadCreated).toHaveBeenCalled(); expect(threadUpdated).toHaveBeenCalled(); - expect(messageCreated).toHaveBeenCalledWith('thread-1', expect.objectContaining({ id: 'msg-1' })); - expect(runStatus).toHaveBeenCalledWith('thread-1', expect.objectContaining({ id: 'run-1' })); + expect(messageCreated).toHaveBeenCalledWith('thread-1', 'user-1', expect.objectContaining({ id: 'msg-1' })); + expect(runStatus).toHaveBeenCalledWith({ + threadId: 'thread-1', + ownerUserId: expect.any(String), + run: expect.objectContaining({ id: 'run-1' }), + }); }); it('schedules metrics for thread_metrics events', () => { diff --git a/packages/platform-server/__tests__/helpers/prisma.stub.ts b/packages/platform-server/__tests__/helpers/prisma.stub.ts index 777248a50..9da69ede3 100644 --- a/packages/platform-server/__tests__/helpers/prisma.stub.ts +++ b/packages/platform-server/__tests__/helpers/prisma.stub.ts @@ -11,6 +11,7 @@ export function createPrismaStub() { channel: any; channelNodeId: string | null; assignedAgentNodeId: string | null; + ownerUserId: string; }> = []; const runs: Array<{ id: string; threadId: string; status: string; createdAt: Date; updatedAt: Date }> = []; const messages: Array<{ id: string; kind: string; text: string | null; source: any; createdAt: Date }> = []; @@ -48,6 +49,7 @@ export function createPrismaStub() { channel: data.channel ?? null, channelNodeId: data.channelNodeId ?? null, assignedAgentNodeId: data.assignedAgentNodeId ?? null, + ownerUserId: data.ownerUserId ?? 'user-default', }; threads.push(row); return row; @@ -61,6 +63,7 @@ export function createPrismaStub() { if (Object.prototype.hasOwnProperty.call(data, 'channel')) next.channel = data.channel ?? null; if (Object.prototype.hasOwnProperty.call(data, 'channelNodeId')) next.channelNodeId = data.channelNodeId ?? null; if (Object.prototype.hasOwnProperty.call(data, 'assignedAgentNodeId')) next.assignedAgentNodeId = data.assignedAgentNodeId ?? null; + if (Object.prototype.hasOwnProperty.call(data, 'ownerUserId')) next.ownerUserId = data.ownerUserId ?? 'user-default'; threads[idx] = next as any; return threads[idx]; }, @@ -74,12 +77,18 @@ export function createPrismaStub() { (where.assignedAgentNodeId === null ? t.assignedAgentNodeId !== null : t.assignedAgentNodeId !== where.assignedAgentNodeId) ) continue; + if (Object.prototype.hasOwnProperty.call(where, 'ownerUserId') && t.ownerUserId !== where.ownerUserId) continue; if (Object.prototype.hasOwnProperty.call(data, 'summary')) t.summary = data.summary ?? null; if (Object.prototype.hasOwnProperty.call(data, 'assignedAgentNodeId')) t.assignedAgentNodeId = data.assignedAgentNodeId ?? null; + if (Object.prototype.hasOwnProperty.call(data, 'ownerUserId')) t.ownerUserId = data.ownerUserId ?? 'user-default'; count += 1; } return { count }; }, + findFirst: async ({ where, orderBy, select }: any) => { + const rows = await prisma.thread.findMany({ where, orderBy, select }); + return rows[0] ?? null; + }, findMany: async (args: any) => { let rows = [...threads]; const where = args?.where || {}; @@ -90,6 +99,7 @@ export function createPrismaStub() { rows = rows.filter((t) => ids.has(t.parentId)); } if (where.status) rows = rows.filter((t) => t.status === where.status); + if (where.ownerUserId) rows = rows.filter((t) => t.ownerUserId === where.ownerUserId); if (args?.orderBy?.createdAt === 'desc') rows.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime()); const take = args?.take; const selected = rows.slice(0, take || rows.length); @@ -114,6 +124,7 @@ export function createPrismaStub() { rows = rows.filter((t) => ids.has(t.parentId)); } if (where?.status) rows = rows.filter((t) => t.status === where.status); + if (where?.ownerUserId) rows = rows.filter((t) => t.ownerUserId === where.ownerUserId); const grouped = new Map(); for (const row of rows) { const key = row.parentId ?? null; diff --git a/packages/platform-server/__tests__/helpers/userService.stub.ts b/packages/platform-server/__tests__/helpers/userService.stub.ts new file mode 100644 index 000000000..e9e8600e6 --- /dev/null +++ b/packages/platform-server/__tests__/helpers/userService.stub.ts @@ -0,0 +1,8 @@ +import type { UserService } from '../../src/auth/user.service'; + +export function createUserServiceStub(overrides?: Partial): UserService { + const base = { + ensureDefaultUser: async () => ({ id: 'user-default' }), + } as Partial; + return { ...base, ...overrides } as UserService; +} diff --git a/packages/platform-server/__tests__/memory.controller.test.ts b/packages/platform-server/__tests__/memory.controller.test.ts index f2af09890..fcb8ff8cd 100644 --- a/packages/platform-server/__tests__/memory.controller.test.ts +++ b/packages/platform-server/__tests__/memory.controller.test.ts @@ -3,6 +3,7 @@ import { MemoryController } from '../src/graph/controllers/memory.controller'; import { HttpException } from '@nestjs/common'; import type { MemoryScope } from '../src/nodes/memory/memory.types'; import type { MemoryService } from '../src/nodes/memory/memory.service'; +import type { AgentsPersistenceService } from '../src/agents/agents.persistence.service'; type MemoryServiceStub = { listDocs: ReturnType; @@ -30,11 +31,16 @@ const createServiceStub = (): MemoryServiceStub => ({ describe('MemoryController', () => { let service: MemoryServiceStub; + let persistence: { getThreadById: ReturnType }; let controller: MemoryController; + const principal = { userId: 'user-1' } as any; beforeEach(() => { service = createServiceStub(); - controller = new MemoryController(service as unknown as MemoryService); + persistence = { + getThreadById: vi.fn(async () => ({ id: 'thread-1', ownerUserId: principal.userId })), + } satisfies Pick; + controller = new MemoryController(service as unknown as MemoryService, persistence as AgentsPersistenceService); }); it('listDocs returns service payload', async () => { @@ -43,9 +49,30 @@ describe('MemoryController', () => { { nodeId: 'b', scope: 'perThread' as MemoryScope, threadId: 'thread-1' }, ]); - const result = await controller.listDocs(); + const result = await controller.listDocs(principal); expect(service.listDocs).toHaveBeenCalledTimes(1); + expect(persistence.getThreadById).toHaveBeenCalledWith('thread-1', { ownerUserId: principal.userId }); + expect(result).toEqual({ + items: [ + { nodeId: 'a', scope: 'global' }, + { nodeId: 'b', scope: 'perThread', threadId: 'thread-1' }, + ], + }); + }); + + it('listDocs filters out threads not owned by principal', async () => { + service.listDocs.mockResolvedValue([ + { nodeId: 'a', scope: 'global' as MemoryScope }, + { nodeId: 'b', scope: 'perThread' as MemoryScope, threadId: 'thread-1' }, + { nodeId: 'b', scope: 'perThread' as MemoryScope, threadId: 'thread-2' }, + ]); + persistence.getThreadById.mockImplementation(async (threadId: string) => + threadId === 'thread-1' ? { id: 'thread-1', ownerUserId: principal.userId } : null, + ); + + const result = await controller.listDocs(principal); + expect(result).toEqual({ items: [ { nodeId: 'a', scope: 'global' }, @@ -56,7 +83,12 @@ describe('MemoryController', () => { it('append requires thread id for perThread scope', async () => { await expect( - controller.append({ nodeId: 'node', scope: 'perThread' } as any, { path: '/note.txt', data: 'hello' } as any, {} as any), + controller.append( + { nodeId: 'node', scope: 'perThread' } as any, + { path: '/note.txt', data: 'hello' } as any, + {} as any, + principal, + ), ).rejects.toBeInstanceOf(HttpException); expect(service.append).not.toHaveBeenCalled(); }); @@ -68,15 +100,17 @@ describe('MemoryController', () => { { nodeId: 'node', scope: 'perThread' } as any, { path: '/note.txt', data: 'hello', threadId: ' thread-1 ' } as any, {} as any, + principal, ); + expect(persistence.getThreadById).toHaveBeenCalledWith('thread-1', { ownerUserId: principal.userId }); expect(service.append).toHaveBeenCalledWith('node', 'perThread', 'thread-1', '/note.txt', 'hello'); }); it('list forwards defaults and thread resolution for global scope', async () => { service.list.mockResolvedValue([{ name: 'logs', hasSubdocs: true }]); - const result = await controller.list({ nodeId: 'node', scope: 'global' } as any, {} as any); + const result = await controller.list({ nodeId: 'node', scope: 'global' } as any, {} as any, principal); expect(service.list).toHaveBeenCalledWith('node', 'global', undefined, '/'); expect(result).toEqual({ items: [{ name: 'logs', hasSubdocs: true }] }); @@ -84,14 +118,14 @@ describe('MemoryController', () => { it('read passes through content', async () => { service.read.mockResolvedValueOnce('hello world'); - const ok = await controller.read({ nodeId: 'node', scope: 'global' } as any, { path: '/note.txt' } as any); + const ok = await controller.read({ nodeId: 'node', scope: 'global' } as any, { path: '/note.txt' } as any, principal); expect(ok).toEqual({ content: 'hello world' }); }); it('read allows root path', async () => { service.read.mockResolvedValueOnce(''); - const ok = await controller.read({ nodeId: 'node', scope: 'global' } as any, { path: '/' } as any); + const ok = await controller.read({ nodeId: 'node', scope: 'global' } as any, { path: '/' } as any, principal); expect(service.read).toHaveBeenCalledWith('node', 'global', undefined, '/'); expect(ok).toEqual({ content: '' }); @@ -100,7 +134,7 @@ describe('MemoryController', () => { it('read maps ENOENT to 404', async () => { service.read.mockRejectedValueOnce(new Error('ENOENT: missing')); - await expect(controller.read({ nodeId: 'node', scope: 'global' } as any, { path: '/missing' } as any)).rejects.toSatisfy((err) => { + await expect(controller.read({ nodeId: 'node', scope: 'global' } as any, { path: '/missing' } as any, principal)).rejects.toSatisfy((err) => { expect(err).toBeInstanceOf(HttpException); expect((err as HttpException).getStatus()).toBe(404); return true; @@ -113,6 +147,7 @@ describe('MemoryController', () => { { nodeId: 'node', scope: 'perThread' } as any, { path: '/note.txt', oldStr: 'a', newStr: 'b', threadId: ' thread-1 ' } as any, {} as any, + principal, ); expect(ok).toEqual({ replaced: 2 }); expect(service.update).toHaveBeenCalledWith('node', 'perThread', 'thread-1', '/note.txt', 'a', 'b'); @@ -123,6 +158,7 @@ describe('MemoryController', () => { { nodeId: 'node', scope: 'perThread' } as any, { path: '/note.txt', oldStr: 'a', newStr: 'b', threadId: 'thread-1' } as any, {} as any, + principal, ), ).rejects.toSatisfy((err) => { expect(err).toBeInstanceOf(HttpException); @@ -134,7 +170,7 @@ describe('MemoryController', () => { it('delete delegates to service', async () => { service.delete.mockResolvedValue({ removed: 1 }); - const result = await controller.remove({ nodeId: 'node', scope: 'global' } as any, { path: '/file.txt' } as any); + const result = await controller.remove({ nodeId: 'node', scope: 'global' } as any, { path: '/file.txt' } as any, principal); expect(service.delete).toHaveBeenCalledWith('node', 'global', undefined, '/file.txt'); expect(result).toEqual({ removed: 1 }); @@ -144,9 +180,18 @@ describe('MemoryController', () => { const payload = { nodeId: 'node', scope: 'perThread' as MemoryScope, threadId: 'thread-1', data: {}, dirs: {} }; service.dump.mockResolvedValue(payload); - const result = await controller.dump({ nodeId: 'node', scope: 'perThread' } as any, { threadId: ' thread-1 ' } as any); + const result = await controller.dump({ nodeId: 'node', scope: 'perThread' } as any, { threadId: ' thread-1 ' } as any, principal); expect(service.dump).toHaveBeenCalledWith('node', 'perThread', 'thread-1'); expect(result).toBe(payload); }); + + it('rejects per-thread requests for threads not owned by principal', async () => { + persistence.getThreadById.mockResolvedValue(null); + + await expect( + controller.read({ nodeId: 'node', scope: 'perThread' } as any, { path: '/note.txt', threadId: 't-2' } as any, principal), + ).rejects.toMatchObject({ status: 404 }); + expect(service.read).not.toHaveBeenCalled(); + }); }); diff --git a/packages/platform-server/__tests__/run-events.publish.test.ts b/packages/platform-server/__tests__/run-events.publish.test.ts index 3d8bc2774..9d756f20e 100644 --- a/packages/platform-server/__tests__/run-events.publish.test.ts +++ b/packages/platform-server/__tests__/run-events.publish.test.ts @@ -2,6 +2,8 @@ import { describe, it, expect, beforeEach, afterAll, afterEach, vi } from 'vites import { PrismaClient, ToolExecStatus } from '@prisma/client'; import { randomUUID } from 'node:crypto'; import type { PrismaService } from '../src/core/services/prisma.service'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; import { RunEventsService, type RunTimelineEvent } from '../src/events/run-events.service'; import { EventsBusService } from '../src/events/events-bus.service'; import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; @@ -37,7 +39,9 @@ maybeDescribe('RunEventsService publishEvent broadcasting', () => { const runtime = { subscribe: vi.fn() } as any; const metrics = { getThreadsMetrics: vi.fn().mockResolvedValue({}) } as any; const prismaStub = { getClient: vi.fn().mockReturnValue({ $queryRaw: vi.fn().mockResolvedValue([]) }) } as any; - gateway = new GraphSocketGateway(runtime, metrics, prismaStub, eventsBus); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authServiceStub = { resolvePrincipalFromCookieHeader: vi.fn() } as unknown as AuthService; + gateway = new GraphSocketGateway(runtime, metrics, prismaStub, eventsBus, configStub, authServiceStub); emitRunEventSpy = vi.spyOn(gateway, 'emitRunEvent'); await gateway.onModuleInit(); }); diff --git a/packages/platform-server/__tests__/socket.events.test.ts b/packages/platform-server/__tests__/socket.events.test.ts index 6a2a597d8..4831d767f 100644 --- a/packages/platform-server/__tests__/socket.events.test.ts +++ b/packages/platform-server/__tests__/socket.events.test.ts @@ -3,6 +3,8 @@ import { FastifyAdapter } from '@nestjs/platform-fastify'; import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; import { PrismaService } from '../src/core/services/prisma.service'; import { ThreadsMetricsService } from '../src/agents/threads.metrics.service'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; import Node from '../src/nodes/base/Node'; // Minimal Test Node to trigger status changes @@ -31,7 +33,9 @@ describe('Socket events', () => { subscribeToThreadMetrics: () => () => {}, subscribeToThreadMetricsAncestors: () => () => {}, }; - const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: async () => ({ userId: 'test-user' }) } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any, configStub, authStub); gateway.init({ server: fastify.server }); const emitMap = new Map>(); @@ -81,7 +85,9 @@ describe('Socket events', () => { subscribeToThreadMetrics: () => () => {}, subscribeToThreadMetricsAncestors: () => () => {}, }; - const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: async () => ({ userId: 'test-user' }) } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any, configStub, authStub); gateway.init({ server: fastify.server }); const emitMap = new Map>(); const toSpy = vi.fn((room: string) => { @@ -115,7 +121,9 @@ describe('Socket events', () => { subscribeToThreadMetrics: () => () => {}, subscribeToThreadMetricsAncestors: () => () => {}, }; - const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: async () => ({ userId: 'test-user' }) } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any, configStub, authStub); gateway.init({ server: fastify.server }); const emitMap = new Map>(); const toSpy = vi.fn((room: string) => { diff --git a/packages/platform-server/__tests__/socket.gateway.test.ts b/packages/platform-server/__tests__/socket.gateway.test.ts index cf1efbb9f..fb6cf70fd 100644 --- a/packages/platform-server/__tests__/socket.gateway.test.ts +++ b/packages/platform-server/__tests__/socket.gateway.test.ts @@ -3,6 +3,8 @@ import { FastifyAdapter } from '@nestjs/platform-fastify'; import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; import { PrismaService } from '../src/core/services/prisma.service'; import { ThreadsMetricsService } from '../src/agents/threads.metrics.service'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; describe('GraphSocketGateway', () => { it('gateway initializes without errors', async () => { @@ -24,7 +26,9 @@ describe('GraphSocketGateway', () => { subscribeToThreadMetrics: () => () => {}, subscribeToThreadMetricsAncestors: () => () => {}, }; - const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: async () => ({ userId: 'test-user' }) } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtimeStub, metrics, prismaStub, eventsBusStub as any, configStub, authStub); expect(() => gateway.init({ server: fastify.server })).not.toThrow(); }); }); diff --git a/packages/platform-server/__tests__/socket.metrics.coalesce.test.ts b/packages/platform-server/__tests__/socket.metrics.coalesce.test.ts index ed6c42b33..ab1916914 100644 --- a/packages/platform-server/__tests__/socket.metrics.coalesce.test.ts +++ b/packages/platform-server/__tests__/socket.metrics.coalesce.test.ts @@ -1,6 +1,8 @@ import { describe, it, expect, vi } from 'vitest'; import { FastifyAdapter } from '@nestjs/platform-fastify'; import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; describe('GraphSocketGateway metrics coalescing', () => { it('coalesces multiple schedules into single batch computation', async () => { @@ -13,7 +15,18 @@ describe('GraphSocketGateway metrics coalescing', () => { Object.fromEntries(_ids.map((id) => [id, { remindersCount: 0, containersCount: 0, activity: 'idle' as const }])), ); const metricsStub = { getThreadsMetrics } as any; - const prismaStub = { getClient: () => ({ $queryRaw: async () => [] }) } as any; + const prismaStub = { + getClient: () => ({ + $queryRaw: async () => [], + thread: { + findUnique: async ({ where }: { where: { id: string } }) => { + if (where.id === 't1') return { ownerUserId: 'user-1' }; + if (where.id === 't2') return { ownerUserId: 'user-2' }; + return null; + }, + }, + }), + } as any; const eventsBusStub = { subscribeToRunEvents: () => () => {}, subscribeToToolOutputChunk: () => () => {}, @@ -27,25 +40,38 @@ describe('GraphSocketGateway metrics coalescing', () => { subscribeToThreadMetrics: () => () => {}, subscribeToThreadMetricsAncestors: () => () => {}, }; - const gateway = new GraphSocketGateway(runtimeStub, metricsStub, prismaStub, eventsBusStub as any); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: async () => ({ userId: 'test-user' }) } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtimeStub, metricsStub, prismaStub, eventsBusStub as any, configStub, authStub); // Attach and stub io emit sink gateway.init({ server: fastify.server }); - const captured: Array<{ room: string; event: string; payload: any }> = []; - (gateway as any)['io'] = { to: (room: string) => ({ emit: (event: string, payload: any) => { captured.push({ room, event, payload }); } }) }; + const ownerSpy = vi.spyOn(gateway as any, 'getThreadOwnerId').mockImplementation(async (threadId: string) => { + if (threadId === 't1') return 'user-1'; + if (threadId === 't2') return 'user-2'; + return null; + }); + const emitSpy = vi.spyOn(gateway as any, 'emitToUserRooms').mockImplementation(() => {}); gateway.scheduleThreadMetrics('t1'); gateway.scheduleThreadMetrics('t2'); // Advance timers to trigger flush vi.advanceTimersByTime(120); - await Promise.resolve(); + await vi.runOnlyPendingTimersAsync(); // Assert single batch computation and grouped emits to both rooms expect(getThreadsMetrics).toHaveBeenCalledTimes(1); expect(getThreadsMetrics).toHaveBeenCalledWith(['t1', 't2']); - const activityThreadsRoom = captured.filter((e) => e.event === 'thread_activity_changed' && e.room === 'threads'); - const remindersThreadsRoom = captured.filter((e) => e.event === 'thread_reminders_count' && e.room === 'threads'); - expect(activityThreadsRoom.map((e) => e.payload.threadId).sort()).toEqual(['t1', 't2']); - expect(remindersThreadsRoom.map((e) => e.payload.threadId).sort()).toEqual(['t1', 't2']); + const activityThreads = emitSpy.mock.calls + .filter(([, , event]) => event === 'thread_activity_changed') + .map(([, , , payload]) => (payload as { threadId: string }).threadId) + .sort(); + const remindersThreads = emitSpy.mock.calls + .filter(([, , event]) => event === 'thread_reminders_count') + .map(([, , , payload]) => (payload as { threadId: string }).threadId) + .sort(); + expect(activityThreads).toEqual(['t1', 't2']); + expect(remindersThreads).toEqual(['t1', 't2']); + expect(ownerSpy.mock.calls.map((args) => args[0]).sort()).toEqual(['t1', 't2']); vi.useRealTimers(); }); }); diff --git a/packages/platform-server/__tests__/socket.node_status.integration.test.ts b/packages/platform-server/__tests__/socket.node_status.integration.test.ts index aa00be962..eacc465d5 100644 --- a/packages/platform-server/__tests__/socket.node_status.integration.test.ts +++ b/packages/platform-server/__tests__/socket.node_status.integration.test.ts @@ -1,6 +1,8 @@ import { describe, it, expect } from 'vitest'; import { FastifyAdapter } from '@nestjs/platform-fastify'; import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; import Node from '../src/nodes/base/Node'; class DummyNode extends Node> { getPortConfig() { return { sourcePorts: { $self: { kind: 'instance' } } } as const; } } @@ -29,7 +31,9 @@ describe('Gateway node_status integration', () => { subscribeToThreadMetrics: () => () => {}, subscribeToThreadMetricsAncestors: () => () => {}, }; - const gateway = new GraphSocketGateway(runtimeStub, metricsStub as any, prismaStub as any, eventsBusStub as any); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: async () => ({ userId: 'test-user' }) } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtimeStub, metricsStub as any, prismaStub as any, eventsBusStub as any, configStub, authStub); gateway.init({ server: fastify.server }); const node = new DummyNode(); node.init({ nodeId: 'nX' }); diff --git a/packages/platform-server/__tests__/socket.realtime.integration.test.ts b/packages/platform-server/__tests__/socket.realtime.integration.test.ts index 9b5eaeef9..b74760854 100644 --- a/packages/platform-server/__tests__/socket.realtime.integration.test.ts +++ b/packages/platform-server/__tests__/socket.realtime.integration.test.ts @@ -7,6 +7,8 @@ import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; import type { LiveGraphRuntime } from '../src/graph-core/liveGraph.manager'; import type { ThreadsMetricsService } from '../src/agents/threads.metrics.service'; import type { PrismaService } from '../src/core/services/prisma.service'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; import { PrismaClient, ToolExecStatus } from '@prisma/client'; import { RunEventsService } from '../src/events/run-events.service'; import { EventsBusService } from '../src/events/events-bus.service'; @@ -15,6 +17,7 @@ import type { TemplateRegistry } from '../src/graph-core/templateRegistry'; import type { GraphRepository } from '../src/graph/graph.repository'; import { HumanMessage, AIMessage } from '@agyn/llm'; import { CallAgentLinkingService } from '../src/agents/call-agent-linking.service'; +import { UserService } from '../src/auth/user.service'; type MetricsPayload = { activity: 'working' | 'waiting' | 'idle'; remindersCount: number }; @@ -62,6 +65,12 @@ const createPrismaStub = () => }), }) as unknown as PrismaService; +const createConfigStub = (): ConfigService => ({ corsOrigins: [] } as unknown as ConfigService); +const createAuthStub = (): AuthService => + ({ + resolvePrincipalFromCookieHeader: async () => ({ userId: 'test-user' }), + }) as unknown as AuthService; + const createLinkingStub = () => ({ buildInitialMetadata: (params: { tool: 'call_agent' | 'call_engineer'; parentThreadId: string; childThreadId: string }) => ({ @@ -144,7 +153,14 @@ if (!shouldRunRealtimeTests) { await new Promise((resolve) => server.listen(0, resolve)); const { port } = server.address() as AddressInfo; const eventsBus = createEventsBusNoop(); - const gateway = new GraphSocketGateway(runtime, metricsDouble.service, prismaStub, eventsBus); + const gateway = new GraphSocketGateway( + runtime, + metricsDouble.service, + prismaStub, + eventsBus, + createConfigStub(), + createAuthStub(), + ); gateway.onModuleInit(); gateway.init({ server }); @@ -166,6 +182,7 @@ if (!shouldRunRealtimeTests) { createdAt: new Date(), parentId: null, channelNodeId: null, + ownerUserId: 'test-user', }); const createdPayload = await createdPromise; expect(createdPayload.thread.id).toBe(threadId); @@ -179,6 +196,7 @@ if (!shouldRunRealtimeTests) { createdAt: new Date(), parentId: null, channelNodeId: null, + ownerUserId: 'test-user', }); const updatedPayload = await updatedPromise; expect(updatedPayload.thread.summary).toBe('Updated summary'); @@ -203,7 +221,14 @@ if (!shouldRunRealtimeTests) { const prismaService = ({ getClient: () => prisma }) as PrismaService; const runEvents = new RunEventsService(prismaService); const eventsBus = new EventsBusService(runEvents); - const gateway = new GraphSocketGateway(runtime, metricsDouble.service, prismaService, eventsBus); + const gateway = new GraphSocketGateway( + runtime, + metricsDouble.service, + prismaService, + eventsBus, + createConfigStub(), + createAuthStub(), + ); gateway.onModuleInit(); const server = createServer(); @@ -222,6 +247,7 @@ if (!shouldRunRealtimeTests) { const templateRegistryStub = ({ getMeta: () => undefined }) as unknown as TemplateRegistry; const graphRepositoryStub = ({ get: async () => ({ nodes: [] }) }) as unknown as GraphRepository; + const userService = new UserService(prismaService); const agents = new AgentsPersistenceService( prismaService, metricsDouble.service, @@ -230,6 +256,7 @@ if (!shouldRunRealtimeTests) { runEvents, createLinkingStub(), eventsBus, + userService, ); const startResult = await agents.beginRunThread(thread.id, [HumanMessage.fromText('hello')]); @@ -267,7 +294,14 @@ if (!shouldRunRealtimeTests) { const prismaService = ({ getClient: () => prisma }) as PrismaService; const runEvents = new RunEventsService(prismaService); const eventsBus = new EventsBusService(runEvents); - const gateway = new GraphSocketGateway(runtime, metricsDouble.service, prismaService, eventsBus); + const gateway = new GraphSocketGateway( + runtime, + metricsDouble.service, + prismaService, + eventsBus, + createConfigStub(), + createAuthStub(), + ); gateway.onModuleInit(); const server = createServer(); @@ -277,6 +311,7 @@ if (!shouldRunRealtimeTests) { const templateRegistryStub = ({ getMeta: () => undefined }) as unknown as TemplateRegistry; const graphRepositoryStub = ({ get: async () => ({ nodes: [] }) }) as unknown as GraphRepository; + const userService = new UserService(prismaService); const agents = new AgentsPersistenceService( prismaService, metricsDouble.service, @@ -285,6 +320,7 @@ if (!shouldRunRealtimeTests) { runEvents, createLinkingStub(), eventsBus, + userService, ); const thread = await prisma.thread.create({ data: { alias: `thread-${randomUUID()}`, summary: 'timeline' } }); diff --git a/packages/platform-server/__tests__/sql.threads.metrics.queries.test.ts b/packages/platform-server/__tests__/sql.threads.metrics.queries.test.ts index 1b9e81b59..10e962e7a 100644 --- a/packages/platform-server/__tests__/sql.threads.metrics.queries.test.ts +++ b/packages/platform-server/__tests__/sql.threads.metrics.queries.test.ts @@ -2,6 +2,8 @@ import { describe, it, expect, vi } from 'vitest'; import { ThreadsMetricsService } from '../src/agents/threads.metrics.service'; import { GraphSocketGateway } from '../src/gateway/graph.socket.gateway'; import type { PrismaService } from '../src/core/services/prisma.service'; +import type { ConfigService } from '../src/core/services/config.service'; +import type { AuthService } from '../src/auth/auth.service'; describe('SQL: WITH RECURSIVE and UUID casts', () => { it('getThreadsMetrics uses WITH RECURSIVE and ::uuid[] and returns expected aggregation', async () => { @@ -64,7 +66,9 @@ describe('SQL: WITH RECURSIVE and UUID casts', () => { const metricsStub = { getThreadsMetrics: vi.fn(async () => ({})) }; const runtimeStub = { subscribe: () => () => {} } as any; const eventsBusStub = {} as any; - const gateway = new GraphSocketGateway(runtimeStub, metricsStub as any, prismaStub, eventsBusStub); + const configStub = { corsOrigins: [] } as unknown as ConfigService; + const authStub = { resolvePrincipalFromCookieHeader: vi.fn() } as unknown as AuthService; + const gateway = new GraphSocketGateway(runtimeStub, metricsStub as any, prismaStub, eventsBusStub, configStub, authStub); const scheduled: string[] = []; // Spy/override scheduleThreadMetrics to capture scheduled ids diff --git a/packages/platform-server/__tests__/startupRecovery.service.spec.ts b/packages/platform-server/__tests__/startupRecovery.service.spec.ts index 848f8ad6b..d4f38d8f8 100644 --- a/packages/platform-server/__tests__/startupRecovery.service.spec.ts +++ b/packages/platform-server/__tests__/startupRecovery.service.spec.ts @@ -20,7 +20,7 @@ class CaptureEventsBus { readonly runStatusChanges: Array<{ threadId: string; runId: string; status: RunStatus }> = []; readonly metricsScheduled: string[] = []; - emitRunStatusChanged(payload: { threadId: string; run: { id: string; status: RunStatus; createdAt: Date; updatedAt: Date } }): void { + emitRunStatusChanged(payload: { threadId: string; ownerUserId: string; run: { id: string; status: RunStatus; createdAt: Date; updatedAt: Date } }): void { this.runStatusChanges.push({ threadId: payload.threadId, runId: payload.run.id, status: payload.run.status }); } diff --git a/packages/platform-server/package.json b/packages/platform-server/package.json index 4bcc03331..8f05482dd 100644 --- a/packages/platform-server/package.json +++ b/packages/platform-server/package.json @@ -21,9 +21,9 @@ "prisma:studio": "prisma studio" }, "dependencies": { - "@agyn/shared": "workspace:*", "@agyn/json-schema-to-zod": "workspace:*", "@agyn/llm": "workspace:*", + "@agyn/shared": "workspace:*", "@fastify/cors": "^11.1.0", "@fastify/websocket": "^11.2.0", "@langchain/core": "1.0.1", @@ -43,6 +43,7 @@ "@types/json-schema": "^7.0.15", "class-transformer": "^0.5.1", "class-validator": "^0.14.1", + "cookie": "^1.1.1", "dockerode": "^4.0.8", "dotenv": "^17.2.2", "fast-glob": "^3.3.2", @@ -52,9 +53,10 @@ "mustache": "^4.2.0", "nestjs-pino": "^4.5.0", "node-fetch-native": "^1.6.7", - "picomatch": "^4.0.2", "openai": "^6.6.0", + "openid-client": "^6.8.1", "p-limit": "^3.1.0", + "picomatch": "^4.0.2", "pino": "^10.1.0", "reflect-metadata": "^0.2.2", "rxjs": "^7.8.1", @@ -75,8 +77,8 @@ "@types/json-schema": "^7.0.15", "@types/lodash-es": "^4.17.12", "@types/md5": "^2.3.5", - "@types/node": "^24.5.1", "@types/mustache": "^4.2.5", + "@types/node": "^24.5.1", "@types/semver": "^7.5.8", "@types/tar-stream": "^2.2.3", "@types/ws": "^8.5.11", diff --git a/packages/platform-server/prisma/migrations/20260128214121_auth_and_sessions/migration.sql b/packages/platform-server/prisma/migrations/20260128214121_auth_and_sessions/migration.sql new file mode 100644 index 000000000..9177c62cb --- /dev/null +++ b/packages/platform-server/prisma/migrations/20260128214121_auth_and_sessions/migration.sql @@ -0,0 +1,59 @@ +-- Add column nullable first so existing rows can be backfilled +ALTER TABLE "Thread" ADD COLUMN "ownerUserId" UUID; + +-- CreateTable +CREATE TABLE "User" ( + "id" UUID NOT NULL, + "email" TEXT, + "name" TEXT, + "oidcIssuer" TEXT, + "oidcSubject" TEXT, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "User_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "Session" ( + "id" UUID NOT NULL, + "userId" UUID NOT NULL, + "expiresAt" TIMESTAMP(3) NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "Session_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "User_email_key" ON "User"("email"); + +-- CreateIndex +CREATE UNIQUE INDEX "User_oidcIssuer_oidcSubject_key" ON "User"("oidcIssuer", "oidcSubject"); + +-- CreateIndex +CREATE INDEX "Session_userId_idx" ON "Session"("userId"); + +-- CreateIndex +CREATE INDEX "Thread_ownerUserId_idx" ON "Thread"("ownerUserId"); + +-- Seed deterministic default principal for existing data +INSERT INTO "User" ("id", "email", "name", "createdAt", "updatedAt") +VALUES ('00000000-0000-0000-0000-000000000001', 'default@local', 'Default User', NOW(), NOW()) +ON CONFLICT ("id") DO NOTHING; + +-- Backfill existing threads to default owner before enforcing NOT NULL +UPDATE "Thread" +SET "ownerUserId" = COALESCE("ownerUserId", '00000000-0000-0000-0000-000000000001'); + +-- Enforce NOT NULL once data is migrated +ALTER TABLE "Thread" ALTER COLUMN "ownerUserId" SET NOT NULL; + +-- AddForeignKey +ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "Thread" ADD CONSTRAINT "Thread_ownerUserId_fkey" FOREIGN KEY ("ownerUserId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "Reminder" ADD CONSTRAINT "Reminder_threadId_fkey" FOREIGN KEY ("threadId") REFERENCES "Thread"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/platform-server/prisma/schema.prisma b/packages/platform-server/prisma/schema.prisma index 57c535ee4..735748292 100644 --- a/packages/platform-server/prisma/schema.prisma +++ b/packages/platform-server/prisma/schema.prisma @@ -50,6 +50,32 @@ model UserProfile { updatedAt DateTime @updatedAt } +model User { + id String @id @default(uuid()) @db.Uuid + email String? @unique + name String? + oidcIssuer String? + oidcSubject String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + threads Thread[] + sessions Session[] + + @@unique([oidcIssuer, oidcSubject]) +} + +model Session { + id String @id @default(uuid()) @db.Uuid + userId String @db.Uuid + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + expiresAt DateTime + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@index([userId]) +} + enum MemoryScope { global perThread @@ -107,6 +133,8 @@ model Thread { alias String @unique summary String? status ThreadStatus @default(open) + ownerUserId String @db.Uuid + owner User @relation(fields: [ownerUserId], references: [id], onDelete: Restrict) channelNodeId String? @map("channel_node_id") @db.Uuid assignedAgentNodeId String? @map("assigned_agent_node_id") @db.Uuid // Channel descriptor for replying to origin channel @@ -117,12 +145,14 @@ model Thread { runs Run[] containers Container[] runEvents RunEvent[] + reminders Reminder[] // Index parentId to accelerate parent-child lookups @@index([parentId]) // Filter by thread status (open/closed) @@index([status]) @@index([assignedAgentNodeId]) + @@index([ownerUserId]) } enum ThreadStatus { @@ -512,6 +542,7 @@ model ContainerEvent { model Reminder { id String @id @default(uuid()) @db.Uuid threadId String @db.Uuid + thread Thread @relation(fields: [threadId], references: [id], onDelete: Cascade) note String at DateTime createdAt DateTime @default(now()) diff --git a/packages/platform-server/src/agents/agents.persistence.service.ts b/packages/platform-server/src/agents/agents.persistence.service.ts index 59713404f..c562e350b 100644 --- a/packages/platform-server/src/agents/agents.persistence.service.ts +++ b/packages/platform-server/src/agents/agents.persistence.service.ts @@ -20,6 +20,7 @@ import { RunEventsService } from '../events/run-events.service'; import { EventsBusService } from '../events/events-bus.service'; import { CallAgentLinkingService } from './call-agent-linking.service'; import { ThreadsMetricsService, type ThreadMetrics } from './threads.metrics.service'; +import { UserService } from '../auth/user.service'; export type RunStartResult = { runId: string }; @@ -63,6 +64,7 @@ export class AgentsPersistenceService { @Inject(RunEventsService) private readonly runEvents: RunEventsService, @Inject(CallAgentLinkingService) private readonly callAgentLinking: CallAgentLinkingService, @Inject(EventsBusService) private readonly eventsBus: EventsBusService, + @Inject(UserService) private readonly users: UserService, ) {} private format(context?: Record): string { @@ -84,6 +86,21 @@ export class AgentsPersistenceService { return (summary ?? '').trim().slice(0, 256); } + private async resolveOwnerId(ownerUserId?: string): Promise { + if (ownerUserId) return ownerUserId; + const user = await this.users.ensureDefaultUser(); + return user.id; + } + + private async getThreadOwnerId(threadId: string, tx?: Prisma.TransactionClient): Promise { + const prisma = tx ?? this.prisma; + const thread = await prisma.thread.findUnique({ where: { id: threadId }, select: { ownerUserId: true } }); + if (!thread) { + throw new Error('thread_not_found'); + } + return thread.ownerUserId; + } + async ensureThreadModel(threadId: string, model: string): Promise { if (!model || model.trim().length === 0) { throw new Error('agent_model_required'); @@ -127,6 +144,7 @@ export class AgentsPersistenceService { parentId: true, channelNodeId: true, assignedAgentNodeId: true, + ownerUserId: true, }, }); }); @@ -139,6 +157,7 @@ export class AgentsPersistenceService { summary: updated.summary ?? null, status: updated.status, createdAt: updated.createdAt, + ownerUserId: updated.ownerUserId, parentId: updated.parentId ?? null, channelNodeId: updated.channelNodeId ?? null, assignedAgentNodeId: updated.assignedAgentNodeId ?? null, @@ -152,15 +171,17 @@ export class AgentsPersistenceService { _source: string, alias: string, summary: string, - options?: { channelNodeId?: string }, + options?: { channelNodeId?: string; ownerUserId?: string }, ): Promise { const existing = await this.prisma.thread.findUnique({ where: { alias }, select: { id: true } }); if (existing) return existing.id; const sanitized = this.sanitizeSummary(summary); + const ownerUserId = await this.resolveOwnerId(options?.ownerUserId); const created = await this.prisma.thread.create({ data: { alias, summary: sanitized, + ownerUserId, ...(options?.channelNodeId ? { channelNodeId: options.channelNodeId } : {}), }, }); @@ -170,6 +191,7 @@ export class AgentsPersistenceService { summary: created.summary ?? null, status: created.status, createdAt: created.createdAt, + ownerUserId, parentId: created.parentId ?? null, channelNodeId: created.channelNodeId ?? null, assignedAgentNodeId: created.assignedAgentNodeId ?? null, @@ -191,13 +213,28 @@ export class AgentsPersistenceService { return; } const channelJson = toPrismaJsonValue(parsed.data); - const updated = await this.prisma.thread.update({ where: { id: threadId }, data: { channel: channelJson } }); + const updated = await this.prisma.thread.update({ + where: { id: threadId }, + data: { channel: channelJson }, + select: { + id: true, + alias: true, + summary: true, + status: true, + createdAt: true, + parentId: true, + channelNodeId: true, + assignedAgentNodeId: true, + ownerUserId: true, + }, + }); this.eventsBus.emitThreadUpdated({ id: updated.id, alias: updated.alias, summary: updated.summary ?? null, status: updated.status, createdAt: updated.createdAt, + ownerUserId: updated.ownerUserId, parentId: updated.parentId ?? null, channelNodeId: updated.channelNodeId ?? null, assignedAgentNodeId: updated.assignedAgentNodeId ?? null, @@ -212,14 +249,21 @@ export class AgentsPersistenceService { const composed = `${source}:${parentThreadId}:${alias}`; const existing = await this.prisma.thread.findUnique({ where: { alias: composed } }); if (existing) return existing.id; + const parent = await this.prisma.thread.findUnique({ where: { id: parentThreadId }, select: { id: true, ownerUserId: true } }); + if (!parent) { + throw new ThreadParentNotFoundError(); + } const sanitized = this.sanitizeSummary(summary); - const created = await this.prisma.thread.create({ data: { alias: composed, parentId: parentThreadId, summary: sanitized } }); + const created = await this.prisma.thread.create({ + data: { alias: composed, parentId: parentThreadId, summary: sanitized, ownerUserId: parent.ownerUserId }, + }); this.eventsBus.emitThreadCreated({ id: created.id, alias: created.alias, summary: created.summary ?? null, status: created.status, createdAt: created.createdAt, + ownerUserId: created.ownerUserId, parentId: created.parentId ?? null, channelNodeId: created.channelNodeId ?? null, assignedAgentNodeId: created.assignedAgentNodeId ?? null, @@ -232,6 +276,7 @@ export class AgentsPersistenceService { alias: string; text: string; agentNodeId: string; + ownerUserId: string; parentId?: string | null; }): Promise<{ id: string; @@ -242,6 +287,7 @@ export class AgentsPersistenceService { parentId: string | null; channelNodeId: string | null; assignedAgentNodeId: string | null; + ownerUserId: string; }> { const alias = params.alias.trim(); if (alias.length === 0) { @@ -253,13 +299,17 @@ export class AgentsPersistenceService { } const parentId = params.parentId ?? null; const sanitizedSummary = this.sanitizeSummary(params.text); + const ownerUserId = await this.resolveOwnerId(params.ownerUserId); const created = await this.prisma.$transaction(async (tx: Prisma.TransactionClient) => { if (parentId) { - const parent = await tx.thread.findUnique({ where: { id: parentId }, select: { id: true } }); + const parent = await tx.thread.findUnique({ where: { id: parentId }, select: { id: true, ownerUserId: true } }); if (!parent) { throw new ThreadParentNotFoundError(); } + if (parent.ownerUserId !== ownerUserId) { + throw new Error('thread_parent_owner_mismatch'); + } } return tx.thread.create({ @@ -268,6 +318,7 @@ export class AgentsPersistenceService { summary: sanitizedSummary, parentId, assignedAgentNodeId: agentNodeId, + ownerUserId, }, }); }); @@ -278,6 +329,7 @@ export class AgentsPersistenceService { summary: created.summary ?? null, status: created.status, createdAt: created.createdAt, + ownerUserId: created.ownerUserId, parentId: created.parentId ?? null, channelNodeId: created.channelNodeId ?? null, assignedAgentNodeId: created.assignedAgentNodeId ?? null, @@ -293,6 +345,7 @@ export class AgentsPersistenceService { summary: created.summary ?? null, status: created.status, createdAt: created.createdAt, + ownerUserId: created.ownerUserId, parentId: created.parentId ?? null, channelNodeId: created.channelNodeId ?? null, assignedAgentNodeId: created.assignedAgentNodeId ?? null, @@ -318,6 +371,7 @@ export class AgentsPersistenceService { parentId: true, channelNodeId: true, assignedAgentNodeId: true, + ownerUserId: true, }, }); if (!updated) return; @@ -327,6 +381,7 @@ export class AgentsPersistenceService { summary: updated.summary ?? null, status: updated.status, createdAt: updated.createdAt, + ownerUserId: updated.ownerUserId, parentId: updated.parentId ?? null, channelNodeId: updated.channelNodeId ?? null, assignedAgentNodeId: updated.assignedAgentNodeId ?? null, @@ -349,6 +404,7 @@ export class AgentsPersistenceService { inputMessages: Array, agentNodeId?: string, ): Promise { + const ownerUserId = await this.getThreadOwnerId(threadId); const { runId, createdMessages, eventIds, patchedEventIds } = await this.prisma.$transaction(async (tx: Prisma.TransactionClient) => { // Begin run and persist messages const run = await tx.run.create({ data: { threadId, status: 'running' as RunStatus } }); @@ -394,11 +450,13 @@ export class AgentsPersistenceService { }); this.eventsBus.emitRunStatusChanged({ threadId, + ownerUserId, run: { id: runId, status: 'running' as RunStatus, createdAt: new Date(), updatedAt: new Date() }, }); for (const m of createdMessages) { this.eventsBus.emitMessageCreated({ threadId, + ownerUserId, message: { id: m.id, kind: m.kind, text: m.text, source: m.source as Prisma.JsonValue, createdAt: m.createdAt, runId }, }); } @@ -477,18 +535,22 @@ export class AgentsPersistenceService { if (!threadId) return { messageIds: [] }; - for (const m of createdMessages) { - this.eventsBus.emitMessageCreated({ - threadId, - message: { - id: m.id, - kind: m.kind, - text: m.text, - source: m.source as Prisma.JsonValue, - createdAt: m.createdAt, - runId, - }, - }); + if (threadId) { + const ownerUserId = await this.getThreadOwnerId(threadId); + for (const m of createdMessages) { + this.eventsBus.emitMessageCreated({ + threadId, + ownerUserId, + message: { + id: m.id, + kind: m.kind, + text: m.text, + source: m.source as Prisma.JsonValue, + createdAt: m.createdAt, + runId, + }, + }); + } } await Promise.all(eventIds.map((id) => this.eventsBus.publishEvent(id, 'append'))); @@ -505,6 +567,7 @@ export class AgentsPersistenceService { }): Promise<{ messageId: string }> { const normalizedThreadId = params.threadId?.trim(); if (!normalizedThreadId) throw new Error('thread_id_required'); + const ownerUserId = await this.getThreadOwnerId(normalizedThreadId); const assistant = AIMessage.fromText(params.text ?? ''); const sourcePayload = toPrismaJsonValue(assistant.toPlain()); @@ -556,6 +619,7 @@ export class AgentsPersistenceService { this.eventsBus.emitMessageCreated({ threadId: normalizedThreadId, + ownerUserId, message: { id: message.id, kind: 'assistant' as MessageKind, @@ -616,13 +680,19 @@ export class AgentsPersistenceService { return updated; }); const threadId = run.threadId; + const ownerUserId = await this.getThreadOwnerId(threadId); for (const m of createdMessages) { this.eventsBus.emitMessageCreated({ threadId, + ownerUserId, message: { id: m.id, kind: m.kind, text: m.text, source: m.source as Prisma.JsonValue, createdAt: m.createdAt, runId }, }); } - this.eventsBus.emitRunStatusChanged({ threadId, run: { id: runId, status, createdAt: run.createdAt, updatedAt: run.updatedAt } }); + this.eventsBus.emitRunStatusChanged({ + threadId, + ownerUserId, + run: { id: runId, status, createdAt: run.createdAt, updatedAt: run.updatedAt }, + }); this.eventsBus.emitThreadMetrics({ threadId }); await Promise.all(eventIds.map((id) => this.eventsBus.publishEvent(id, 'append'))); await Promise.all(patchedEventIds.map((id) => this.eventsBus.publishEvent(id, 'update'))); @@ -636,6 +706,7 @@ export class AgentsPersistenceService { includeAgentTitles: boolean; childrenStatus: 'open' | 'closed' | 'all'; perParentChildrenLimit: number; + ownerUserId?: string; }): Promise { const limit = Math.min(Math.max(opts.limit, 1), 1000); const depth = Math.min(Math.max(opts.depth, 0), 2) as 0 | 1 | 2; @@ -645,6 +716,7 @@ export class AgentsPersistenceService { const rootWhere: Prisma.ThreadWhereInput = { parentId: null }; if (status !== 'all') rootWhere.status = status as ThreadStatus; + if (opts.ownerUserId) rootWhere.ownerUserId = opts.ownerUserId; const rootRows = await this.prisma.thread.findMany({ where: rootWhere, @@ -704,6 +776,7 @@ export class AgentsPersistenceService { if (depth >= 1) { const childWhere: Prisma.ThreadWhereInput = { parentId: { in: rootIds } }; if (childrenStatus !== 'all') childWhere.status = childrenStatus as ThreadStatus; + if (opts.ownerUserId) childWhere.ownerUserId = opts.ownerUserId; const childRows = await this.prisma.thread.findMany({ where: childWhere, select: { id: true, alias: true, summary: true, status: true, createdAt: true, parentId: true }, @@ -714,6 +787,7 @@ export class AgentsPersistenceService { if (depth >= 2 && childIds.length > 0) { const grandchildWhere: Prisma.ThreadWhereInput = { parentId: { in: childIds } }; if (childrenStatus !== 'all') grandchildWhere.status = childrenStatus as ThreadStatus; + if (opts.ownerUserId) grandchildWhere.ownerUserId = opts.ownerUserId; const grandchildRows = await this.prisma.thread.findMany({ where: grandchildWhere, select: { id: true, alias: true, summary: true, status: true, createdAt: true, parentId: true }, @@ -793,13 +867,19 @@ export class AgentsPersistenceService { return rootNodes.map(clone); } - async listThreads(opts?: { rootsOnly?: boolean; status?: 'open' | 'closed' | 'all'; limit?: number }): Promise> { + async listThreads(opts?: { + rootsOnly?: boolean; + status?: 'open' | 'closed' | 'all'; + limit?: number; + ownerUserId?: string; + }): Promise> { const rootsOnly = opts?.rootsOnly ?? false; const status = opts?.status ?? 'all'; const limit = opts?.limit ?? 100; const where: Prisma.ThreadWhereInput = {}; if (rootsOnly) where.parentId = null; if (status !== 'all') where.status = status as ThreadStatus; + if (opts?.ownerUserId) where.ownerUserId = opts.ownerUserId; return this.prisma.thread.findMany({ where, orderBy: { createdAt: 'desc' }, @@ -810,7 +890,7 @@ export class AgentsPersistenceService { async getThreadById( threadId: string, - opts?: { includeMetrics?: boolean; includeAgentTitles?: boolean }, + opts?: { includeMetrics?: boolean; includeAgentTitles?: boolean; ownerUserId?: string }, ): Promise< | ({ id: string; @@ -820,6 +900,7 @@ export class AgentsPersistenceService { createdAt: Date; parentId: string | null; assignedAgentNodeId: string | null; + ownerUserId: string; metrics?: ThreadMetrics; agentTitle?: string; agentRole?: string; @@ -827,8 +908,8 @@ export class AgentsPersistenceService { }) | null > { - const thread = await this.prisma.thread.findUnique({ - where: { id: threadId }, + const thread = await this.prisma.thread.findFirst({ + where: { id: threadId, ...(opts?.ownerUserId ? { ownerUserId: opts.ownerUserId } : {}) }, select: { id: true, alias: true, @@ -837,6 +918,7 @@ export class AgentsPersistenceService { createdAt: true, parentId: true, assignedAgentNodeId: true, + ownerUserId: true, }, }); if (!thread) return null; @@ -852,6 +934,7 @@ export class AgentsPersistenceService { createdAt: Date; parentId: string | null; assignedAgentNodeId: string | null; + ownerUserId: string; metrics?: ThreadMetrics; agentTitle?: string; agentRole?: string; @@ -860,6 +943,7 @@ export class AgentsPersistenceService { ...thread, parentId: thread.parentId ?? null, assignedAgentNodeId: thread.assignedAgentNodeId ?? null, + ownerUserId: thread.ownerUserId, }; const defaultMetrics: ThreadMetrics = { remindersCount: 0, containersCount: 0, activity: 'idle', runsCount: 0 }; @@ -905,22 +989,34 @@ export class AgentsPersistenceService { return state?.nodeId ?? null; } - async listChildren(parentId: string, status: 'open' | 'closed' | 'all' = 'all'): Promise> { + async listChildren( + parentId: string, + status: 'open' | 'closed' | 'all' = 'all', + ownerUserId?: string, + ): Promise> { const where: Prisma.ThreadWhereInput = { parentId }; if (status !== 'all') where.status = status as ThreadStatus; + if (ownerUserId) where.ownerUserId = ownerUserId; return this.prisma.thread.findMany({ where, orderBy: { createdAt: 'desc' }, select: { id: true, alias: true, summary: true, status: true, createdAt: true, parentId: true } }); } async updateThread( threadId: string, data: { summary?: string | null; status?: ThreadStatus }, + scope?: { ownerUserId?: string }, ): Promise<{ previousStatus: ThreadStatus; status: ThreadStatus }> { const patch: Prisma.ThreadUpdateInput = {}; if (data.summary !== undefined) patch.summary = data.summary; if (data.status !== undefined) patch.status = data.status; const result = await this.prisma.$transaction(async (tx: Prisma.TransactionClient) => { - const current = await tx.thread.findUnique({ where: { id: threadId }, select: { status: true } }); + const current = await tx.thread.findFirst({ + where: { id: threadId, ...(scope?.ownerUserId ? { ownerUserId: scope.ownerUserId } : {}) }, + select: { status: true }, + }); + if (!current) { + throw new Error('thread_not_found'); + } const updated = await tx.thread.update({ where: { id: threadId }, data: patch }); return { updated, previousStatus: current?.status ?? updated.status }; }); @@ -932,6 +1028,7 @@ export class AgentsPersistenceService { summary: updated.summary ?? null, status: updated.status, createdAt: updated.createdAt, + ownerUserId: updated.ownerUserId, parentId: updated.parentId ?? null, channelNodeId: updated.channelNodeId ?? null, assignedAgentNodeId: updated.assignedAgentNodeId ?? null, @@ -1000,11 +1097,19 @@ export class AgentsPersistenceService { }); } - async getRunById(runId: string): Promise<{ id: string; threadId: string; status: RunStatus } | null> { - return this.prisma.run.findUnique({ + async getRunById( + runId: string, + scope?: { ownerUserId?: string }, + ): Promise<{ id: string; threadId: string; status: RunStatus } | null> { + const run = await this.prisma.run.findUnique({ where: { id: runId }, - select: { id: true, threadId: true, status: true }, + select: { id: true, threadId: true, status: true, thread: { select: { ownerUserId: true } } }, }); + if (!run) return null; + if (scope?.ownerUserId && run.thread.ownerUserId !== scope.ownerUserId) { + return null; + } + return { id: run.id, threadId: run.threadId, status: run.status }; } async listRunMessages(runId: string, type: RunMessageType): Promise> { @@ -1019,6 +1124,7 @@ export class AgentsPersistenceService { filter: 'active' | 'completed' | 'cancelled' | 'all' = 'active', take: number = 100, threadId?: string, + ownerUserId?: string, ): Promise> { const limit = Number.isFinite(take) ? Math.min(1000, Math.max(1, Math.trunc(take))) : 100; const where: Prisma.ReminderWhereInput = {}; @@ -1031,6 +1137,7 @@ export class AgentsPersistenceService { where.NOT = { cancelledAt: null }; } if (threadId) where.threadId = threadId; + if (ownerUserId) where.thread = { ownerUserId }; try { return await this.prisma.reminder.findMany({ @@ -1059,6 +1166,7 @@ export class AgentsPersistenceService { sort = 'latest', order = 'desc', threadId, + ownerUserId, }: { filter?: 'all' | 'active' | 'completed' | 'cancelled'; page?: number; @@ -1066,6 +1174,7 @@ export class AgentsPersistenceService { sort?: 'latest' | 'createdAt' | 'at'; order?: 'asc' | 'desc'; threadId?: string; + ownerUserId?: string; }): Promise<{ items: Array<{ id: string; threadId: string; note: string; at: Date; createdAt: Date; completedAt: Date | null; cancelledAt: Date | null }>; page: number; @@ -1086,9 +1195,11 @@ export class AgentsPersistenceService { const filterKey: 'all' | 'active' | 'completed' | 'cancelled' = filter ?? 'all'; const skip = (normalizedPage - 1) * normalizedPageSize; - const { where, clauses } = this.buildReminderFilter(filterKey, threadId); + const { where, clauses } = this.buildReminderFilter(filterKey, threadId, ownerUserId); const whereForQuery = Object.keys(where).length === 0 ? undefined : where; - const countsBaseWhere: Prisma.ReminderWhereInput = threadId ? { threadId } : {}; + const countsBaseWhere: Prisma.ReminderWhereInput = {}; + if (threadId) countsBaseWhere.threadId = threadId; + if (ownerUserId) countsBaseWhere.thread = { ownerUserId }; try { return await this.prisma.$transaction(async (tx) => { @@ -1115,8 +1226,9 @@ export class AgentsPersistenceService { }), ]); + const useLatestAllOptimization = sortKey === 'latest' && filterKey === 'all' && !ownerUserId; const items = - sortKey === 'latest' && filterKey === 'all' + useLatestAllOptimization ? await this.fetchRemindersLatestAll(tx, clauses, skip, normalizedPageSize, sortOrder) : await tx.reminder.findMany({ where: whereForQuery, @@ -1159,6 +1271,7 @@ export class AgentsPersistenceService { sort: sortKey, order: sortOrder, threadId, + ownerUserId, error: this.errorInfo(error), })}`, ); @@ -1169,6 +1282,7 @@ export class AgentsPersistenceService { private buildReminderFilter( filter: 'all' | 'active' | 'completed' | 'cancelled', threadId?: string, + ownerUserId?: string, ): { where: Prisma.ReminderWhereInput; clauses: Prisma.Sql[] } { const where: Prisma.ReminderWhereInput = {}; const clauses: Prisma.Sql[] = []; @@ -1178,6 +1292,10 @@ export class AgentsPersistenceService { clauses.push(Prisma.sql`"threadId" = ${threadId}`); } + if (ownerUserId) { + where.thread = { ownerUserId }; + } + switch (filter) { case 'active': where.completedAt = null; diff --git a/packages/platform-server/src/agents/reminders.controller.ts b/packages/platform-server/src/agents/reminders.controller.ts index 0550e33cf..d12bfc0e2 100644 --- a/packages/platform-server/src/agents/reminders.controller.ts +++ b/packages/platform-server/src/agents/reminders.controller.ts @@ -1,8 +1,10 @@ -import { Controller, Get, Inject, NotFoundException, Param, Post, Query } from '@nestjs/common'; +import { Controller, Get, Inject, NotFoundException, Param, Post, Query, UnauthorizedException } from '@nestjs/common'; import { Type } from 'class-transformer'; import { IsIn, IsInt, Min, Max, IsOptional, IsUUID } from 'class-validator'; import { AgentsPersistenceService } from './agents.persistence.service'; import { RemindersService } from './reminders.service'; +import { CurrentPrincipal } from '../auth/principal.decorator'; +import type { Principal } from '../auth/auth.types'; export class ListRemindersQueryDto { @IsOptional() @@ -49,8 +51,23 @@ export class AgentsRemindersController { @Inject(RemindersService) private readonly reminders: RemindersService, ) {} + private requirePrincipal(principal: Principal | null): Principal { + if (!principal) { + throw new UnauthorizedException({ error: 'unauthorized' }); + } + return principal; + } + @Get('reminders') - async listReminders(@Query() query: ListRemindersQueryDto) { + async listReminders(@Query() query: ListRemindersQueryDto, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + if (query.threadId) { + const thread = await this.persistence.getThreadById(query.threadId, { ownerUserId }); + if (!thread) { + throw new NotFoundException({ error: 'thread_not_found' }); + } + } const wantsPagination = query.page !== undefined || query.pageSize !== undefined || @@ -65,19 +82,22 @@ export class AgentsRemindersController { sort: query.sort ?? 'latest', order: query.order ?? 'desc', threadId: query.threadId, + ownerUserId, }); return result; } const filter = query.filter ?? 'active'; const take = query.take ?? 100; - const items = await this.persistence.listReminders(filter, take, query.threadId); + const items = await this.persistence.listReminders(filter, take, query.threadId, ownerUserId); return { items }; } @Post('reminders/:reminderId/cancel') - async cancelReminder(@Param('reminderId') reminderId: string) { - const result = await this.reminders.cancelReminder({ reminderId, emitMetrics: true }); + async cancelReminder(@Param('reminderId') reminderId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + const result = await this.reminders.cancelReminder({ reminderId, emitMetrics: true, ownerUserId }); if (!result) { throw new NotFoundException({ error: 'reminder_not_found' }); } diff --git a/packages/platform-server/src/agents/reminders.service.ts b/packages/platform-server/src/agents/reminders.service.ts index 983292689..ece5f2a83 100644 --- a/packages/platform-server/src/agents/reminders.service.ts +++ b/packages/platform-server/src/agents/reminders.service.ts @@ -18,6 +18,7 @@ interface CancelReminderOptions { reminderId: string; prismaOverride?: PrismaExecutor; emitMetrics?: boolean; + ownerUserId?: string; } @Injectable() @@ -89,7 +90,7 @@ export class RemindersService { return { cancelledDb, clearedRuntime }; } - async cancelReminder({ reminderId, prismaOverride, emitMetrics }: CancelReminderOptions): Promise< + async cancelReminder({ reminderId, prismaOverride, emitMetrics, ownerUserId }: CancelReminderOptions): Promise< | { threadId: string; cancelledDb: boolean; @@ -101,12 +102,16 @@ export class RemindersService { const reminder = await prisma.reminder.findUnique({ where: { id: reminderId }, - select: { id: true, threadId: true, completedAt: true, cancelledAt: true }, + select: { id: true, threadId: true, completedAt: true, cancelledAt: true, thread: { select: { ownerUserId: true } } }, }); if (!reminder) { return null; } + if (ownerUserId && reminder.thread?.ownerUserId !== ownerUserId) { + return null; + } + const threadId = reminder.threadId ?? null; let cancelledDb = false; diff --git a/packages/platform-server/src/agents/threads.controller.ts b/packages/platform-server/src/agents/threads.controller.ts index ec595dedf..c24ab4efe 100644 --- a/packages/platform-server/src/agents/threads.controller.ts +++ b/packages/platform-server/src/agents/threads.controller.ts @@ -16,6 +16,7 @@ import { Post, Query, ServiceUnavailableException, + UnauthorizedException, } from '@nestjs/common'; import { IsBooleanString, IsIn, IsInt, IsOptional, IsString, IsISO8601, Max, Min, ValidateIf } from 'class-validator'; import { AgentsPersistenceService } from './agents.persistence.service'; @@ -32,6 +33,8 @@ import { TemplateRegistry } from '../graph-core/templateRegistry'; import { hasQueueManagementCapability, hasQueuedPreviewCapability, isAgentLiveNode, isAgentRuntimeInstance } from './agent-node.utils'; import { randomUUID } from 'node:crypto'; import { ThreadParentNotFoundError } from './agents.persistence.service'; +import { CurrentPrincipal } from '../auth/principal.decorator'; +import type { Principal } from '../auth/auth.types'; // Avoid runtime import of Prisma in tests; enumerate allowed values export const RunMessageTypeValues: ReadonlyArray = ['input', 'injected', 'output']; @@ -227,9 +230,41 @@ export class AgentsThreadsController { @Inject(TemplateRegistry) private readonly templateRegistry: TemplateRegistry, ) {} + private requirePrincipal(principal: Principal | null): Principal { + if (!principal) { + throw new UnauthorizedException({ error: 'unauthorized' }); + } + return principal; + } + + private async getThreadOrThrow( + threadId: string, + ownerUserId: string, + opts?: { includeMetrics?: boolean; includeAgentTitles?: boolean }, + ) { + const thread = await this.persistence.getThreadById(threadId, { ...opts, ownerUserId }); + if (!thread) { + throw new NotFoundException({ error: 'thread_not_found' }); + } + return thread; + } + + private async getRunOrThrow(runId: string, ownerUserId: string) { + const run = await this.persistence.getRunById(runId, { ownerUserId }); + if (!run) { + throw new NotFoundException('run_not_found'); + } + return run; + } + @Post('threads') @HttpCode(201) - async createThread(@Body() body: CreateThreadBody | null | undefined): Promise<{ id: string }> { + async createThread( + @Body() body: CreateThreadBody | null | undefined, + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ id: string }> { + const currentPrincipal = this.requirePrincipal(principal ?? null); + const ownerUserId = currentPrincipal.userId; const textValue = typeof body?.text === 'string' ? body.text : ''; const text = textValue.trim(); if (text.length === 0 || text.length > AgentsThreadsController.MAX_MESSAGE_LENGTH) { @@ -246,6 +281,9 @@ export class AgentsThreadsController { const alias = aliasCandidate.length > 0 ? aliasCandidate : `ui:${randomUUID()}`; const parentIdCandidate = typeof body?.parentId === 'string' ? body.parentId.trim() : ''; const parentId = parentIdCandidate.length > 0 ? parentIdCandidate : null; + if (parentId) { + await this.getThreadOrThrow(parentId, ownerUserId); + } const liveNodes = this.runtime.getNodes(); const agentNodes = liveNodes.filter((node) => isAgentLiveNode(node, this.templateRegistry)); @@ -272,6 +310,7 @@ export class AgentsThreadsController { alias, text, agentNodeId, + ownerUserId, parentId, }); threadId = created.id; @@ -279,6 +318,9 @@ export class AgentsThreadsController { if (error instanceof ThreadParentNotFoundError || (error instanceof Error && error.message === 'parent_not_found')) { throw new NotFoundException({ error: 'parent_not_found' }); } + if (error instanceof Error && error.message === 'thread_parent_owner_mismatch') { + throw new NotFoundException({ error: 'parent_not_found' }); + } if (error instanceof Error && (error.message === 'thread_alias_required' || error.message === 'agent_node_id_required')) { throw new BadRequestException({ error: 'bad_message_payload' }); } @@ -303,11 +345,13 @@ export class AgentsThreadsController { } @Get('threads') - async listThreads(@Query() query: ListThreadsQueryDto) { + async listThreads(@Query() query: ListThreadsQueryDto, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal ?? null); + const ownerUserId = currentPrincipal.userId; const rootsOnly = (query.rootsOnly ?? 'false') === 'true'; const status = query.status ?? 'all'; - const limit = Number(query.limit) ?? 100; - const threads = await this.persistence.listThreads({ rootsOnly, status, limit }); + const limit = typeof query.limit === 'number' ? query.limit : 100; + const threads = await this.persistence.listThreads({ rootsOnly, status, limit, ownerUserId }); const includeMetrics = (query.includeMetrics ?? 'false') === 'true'; const includeAgentTitles = (query.includeAgentTitles ?? 'false') === 'true'; const ids = threads.map((t) => t.id); @@ -335,7 +379,9 @@ export class AgentsThreadsController { } @Get('threads/tree') - async listThreadsTree(@Query() query: ListThreadsTreeQueryDto) { + async listThreadsTree(@Query() query: ListThreadsTreeQueryDto, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal ?? null); + const ownerUserId = currentPrincipal.userId; const status = query.status ?? 'all'; const limit = query.limit ?? 50; const depth = (query.depth ?? 2) as 0 | 1 | 2; @@ -351,13 +397,21 @@ export class AgentsThreadsController { includeAgentTitles, childrenStatus, perParentChildrenLimit, + ownerUserId, }); return { items }; } @Get('threads/:threadId/children') - async listChildren(@Param('threadId') threadId: string, @Query() query: ListChildrenQueryDto) { - const items = await this.persistence.listChildren(threadId, query.status ?? 'all'); + async listChildren( + @Param('threadId') threadId: string, + @Query() query: ListChildrenQueryDto, + @CurrentPrincipal() principal: Principal | null, + ) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + await this.getThreadOrThrow(threadId, ownerUserId); + const items = await this.persistence.listChildren(threadId, query.status ?? 'all', ownerUserId); const includeMetrics = (query.includeMetrics ?? 'false') === 'true'; const includeAgentTitles = (query.includeAgentTitles ?? 'false') === 'true'; const ids = items.map((t) => t.id); @@ -386,11 +440,16 @@ export class AgentsThreadsController { } @Get('threads/:threadId') - async getThread(@Param('threadId') threadId: string, @Query() query: GetThreadQueryDto) { + async getThread( + @Param('threadId') threadId: string, + @Query() query: GetThreadQueryDto, + @CurrentPrincipal() principal: Principal | null, + ) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; const includeMetrics = (query.includeMetrics ?? 'false') === 'true'; const includeAgentTitles = (query.includeAgentTitles ?? 'false') === 'true'; - const thread = await this.persistence.getThreadById(threadId, { includeMetrics, includeAgentTitles }); - if (!thread) throw new NotFoundException('thread_not_found'); + const thread = await this.getThreadOrThrow(threadId, ownerUserId, { includeMetrics, includeAgentTitles }); if (!includeMetrics && !includeAgentTitles) return thread; const defaultMetrics: ThreadMetrics = { remindersCount: 0, containersCount: 0, activity: 'idle', runsCount: 0 }; const fallbackTitle = '(unknown agent)'; @@ -404,15 +463,19 @@ export class AgentsThreadsController { } @Get('threads/:threadId/runs') - async listRuns(@Param('threadId') threadId: string) { + async listRuns(@Param('threadId') threadId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + await this.getThreadOrThrow(threadId, ownerUserId); const runs = await this.persistence.listRuns(threadId); return { items: runs }; } @Get('threads/:threadId/queued-messages') - async listQueuedMessages(@Param('threadId') threadId: string) { - const thread = await this.persistence.getThreadById(threadId); - if (!thread) throw new NotFoundException({ error: 'thread_not_found' }); + async listQueuedMessages(@Param('threadId') threadId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + const thread = await this.getThreadOrThrow(threadId, ownerUserId); const assignedAgentNodeId = typeof thread.assignedAgentNodeId === 'string' ? thread.assignedAgentNodeId.trim() : ''; if (!assignedAgentNodeId) { @@ -451,9 +514,10 @@ export class AgentsThreadsController { } @Delete('threads/:threadId/queued-messages') - async clearQueuedMessages(@Param('threadId') threadId: string) { - const thread = await this.persistence.getThreadById(threadId); - if (!thread) throw new NotFoundException({ error: 'thread_not_found' }); + async clearQueuedMessages(@Param('threadId') threadId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + const thread = await this.getThreadOrThrow(threadId, ownerUserId); const assignedAgentNodeId = typeof thread.assignedAgentNodeId === 'string' ? thread.assignedAgentNodeId.trim() : ''; if (!assignedAgentNodeId) { @@ -491,9 +555,10 @@ export class AgentsThreadsController { } @Post('threads/:threadId/reminders/cancel') - async cancelThreadReminders(@Param('threadId') threadId: string) { - const thread = await this.persistence.getThreadById(threadId); - if (!thread) throw new NotFoundException({ error: 'thread_not_found' }); + async cancelThreadReminders(@Param('threadId') threadId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + await this.getThreadOrThrow(threadId, ownerUserId); try { const result = await this.reminders.cancelThreadReminders({ threadId, emitMetrics: true }); @@ -506,13 +571,23 @@ export class AgentsThreadsController { } @Get('runs/:runId/messages') - async listRunMessages(@Param('runId') runId: string, @Query() query: ListRunMessagesQueryDto) { + async listRunMessages( + @Param('runId') runId: string, + @Query() query: ListRunMessagesQueryDto, + @CurrentPrincipal() principal: Principal | null, + ) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + await this.getRunOrThrow(runId, ownerUserId); const items = await this.persistence.listRunMessages(runId, query.type); return { items }; } @Get('runs/:runId/summary') - async getRunTimelineSummary(@Param('runId') runId: string) { + async getRunTimelineSummary(@Param('runId') runId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + await this.getRunOrThrow(runId, ownerUserId); const summary = await this.runEvents.getRunSummary(runId); if (!summary) throw new NotFoundException('run_not_found'); return summary; @@ -524,7 +599,11 @@ export class AgentsThreadsController { @Query() query: RunTimelineEventsQueryDto, @Query('type') typeFilter?: string | string[], @Query('status') statusFilter?: string | string[], + @CurrentPrincipal() principal?: Principal | null, ) { + const currentPrincipal = this.requirePrincipal(principal ?? null); + const ownerUserId = currentPrincipal.userId; + await this.getRunOrThrow(runId, ownerUserId); const collect = (input?: string | string[]) => { if (!input) return [] as string[]; const values = Array.isArray(input) ? input : [input]; @@ -575,7 +654,11 @@ export class AgentsThreadsController { @Param('runId') runId: string, @Param('eventId') eventId: string, @Query() query: RunEventOutputQueryDto, + @CurrentPrincipal() principal: Principal | null, ) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + await this.getRunOrThrow(runId, ownerUserId); try { const snapshot = await this.runEvents.getToolOutputSnapshot({ runId, @@ -595,11 +678,25 @@ export class AgentsThreadsController { } @Patch('threads/:threadId') - async patchThread(@Param('threadId') threadId: string, @Body() body: PatchThreadBodyDto) { + async patchThread( + @Param('threadId') threadId: string, + @Body() body: PatchThreadBodyDto, + @CurrentPrincipal() principal: Principal | null, + ) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; const update: { summary?: string | null; status?: ThreadStatus } = {}; if (body.summary !== undefined) update.summary = body.summary; if (body.status !== undefined) update.status = body.status; - const result = await this.persistence.updateThread(threadId, update); + let result; + try { + result = await this.persistence.updateThread(threadId, update, { ownerUserId }); + } catch (error) { + if (error instanceof Error && error.message === 'thread_not_found') { + throw new NotFoundException({ error: 'thread_not_found' }); + } + throw error; + } if (result.status === 'closed' && result.previousStatus !== 'closed') { void this.cleanupCoordinator.closeThreadWithCascade(threadId); @@ -609,16 +706,19 @@ export class AgentsThreadsController { @Post('threads/:threadId/messages') @HttpCode(202) - async sendThreadMessage(@Param('threadId') threadId: string, @Body() body: unknown): Promise<{ ok: true }> { + async sendThreadMessage( + @Param('threadId') threadId: string, + @Body() body: unknown, + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ ok: true }> { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; const text = this.extractMessageText(body); if (!text) { throw new BadRequestException({ error: 'bad_message_payload' }); } - const thread = await this.persistence.getThreadById(threadId); - if (!thread) { - throw new NotFoundException({ error: 'thread_not_found' }); - } + const thread = await this.getThreadOrThrow(threadId, ownerUserId); if (thread.status === 'closed') { throw new ConflictException({ error: 'thread_closed' }); } @@ -671,15 +771,19 @@ export class AgentsThreadsController { } @Get('threads/:threadId/metrics') - async getThreadMetrics(@Param('threadId') threadId: string) { + async getThreadMetrics(@Param('threadId') threadId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + await this.getThreadOrThrow(threadId, ownerUserId); const metrics = await this.persistence.getThreadsMetrics([threadId]); return metrics[threadId] ?? { remindersCount: 0, containersCount: 0, activity: 'idle' as const, runsCount: 0 }; } @Post('runs/:runId/terminate') - async terminateRun(@Param('runId') runId: string) { - const run = await this.persistence.getRunById(runId); - if (!run) throw new NotFoundException('run_not_found'); + async terminateRun(@Param('runId') runId: string, @CurrentPrincipal() principal: Principal | null) { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; + const run = await this.getRunOrThrow(runId, ownerUserId); if (run.status !== 'running') { return { ok: true }; } diff --git a/packages/platform-server/src/auth/auth.controller.ts b/packages/platform-server/src/auth/auth.controller.ts new file mode 100644 index 000000000..2ff61d3d4 --- /dev/null +++ b/packages/platform-server/src/auth/auth.controller.ts @@ -0,0 +1,43 @@ +import { Controller, Get, Inject, Post, Query, Req, Res } from '@nestjs/common'; +import { IsString } from 'class-validator'; +import type { FastifyReply } from 'fastify'; +import { AuthService } from './auth.service'; +import { Public } from './public.decorator'; +import { CurrentPrincipal } from './principal.decorator'; +import type { AuthStatusResponse, Principal, RequestWithPrincipal } from './auth.types'; + +class OidcCallbackQueryDto { + @IsString() + state!: string; + + @IsString() + code!: string; +} + +@Controller('api/auth') +export class AuthController { + constructor(@Inject(AuthService) private readonly auth: AuthService) {} + + @Get('status') + @Public() + async status(@CurrentPrincipal() principal: Principal | null): Promise { + return this.auth.getAuthStatus(principal); + } + + @Get('login') + @Public() + async login(@Res({ passthrough: true }) reply: FastifyReply): Promise { + await this.auth.initiateLogin(reply); + } + + @Get('oidc/callback') + @Public() + async callback(@Query() query: OidcCallbackQueryDto, @Res({ passthrough: true }) reply: FastifyReply): Promise { + await this.auth.handleOidcCallback(query, reply); + } + + @Post('logout') + async logout(@Req() request: RequestWithPrincipal, @Res({ passthrough: true }) reply: FastifyReply): Promise { + await this.auth.logout(reply, request.sessionId ?? null); + } +} diff --git a/packages/platform-server/src/auth/auth.module.ts b/packages/platform-server/src/auth/auth.module.ts new file mode 100644 index 000000000..231aa4b26 --- /dev/null +++ b/packages/platform-server/src/auth/auth.module.ts @@ -0,0 +1,34 @@ +import { Global, Module } from '@nestjs/common'; +import { APP_GUARD } from '@nestjs/core'; +import { CoreModule } from '../core/core.module'; +import { AuthController } from './auth.controller'; +import { AuthService } from './auth.service'; +import { SessionService } from './session.service'; +import { UserService } from './user.service'; +import { OidcService } from './oidc.service'; +import { LoginStateStore } from './login-state.store'; +import { PrincipalGuard } from './principal.guard'; +import { AuthenticatedGuard } from './authenticated.guard'; + +@Global() +@Module({ + imports: [CoreModule], + controllers: [AuthController], + providers: [ + AuthService, + SessionService, + UserService, + OidcService, + LoginStateStore, + { + provide: APP_GUARD, + useClass: PrincipalGuard, + }, + { + provide: APP_GUARD, + useClass: AuthenticatedGuard, + }, + ], + exports: [AuthService, SessionService, UserService], +}) +export class AuthModule {} diff --git a/packages/platform-server/src/auth/auth.service.ts b/packages/platform-server/src/auth/auth.service.ts new file mode 100644 index 000000000..3533b9879 --- /dev/null +++ b/packages/platform-server/src/auth/auth.service.ts @@ -0,0 +1,169 @@ +import { BadRequestException, Inject, Injectable, Logger, UnauthorizedException } from '@nestjs/common'; +import type { FastifyReply, FastifyRequest } from 'fastify'; +import { createHash, randomBytes } from 'node:crypto'; +import { SessionService } from './session.service'; +import { UserService } from './user.service'; +import { OidcService } from './oidc.service'; +import { LoginStateStore } from './login-state.store'; +import { ConfigService, type AuthMode } from '../core/services/config.service'; +import type { AuthStatusResponse, Principal } from './auth.types'; + +const generateCodeVerifier = (): string => randomBytes(32).toString('base64url'); +const generateCodeChallenge = (verifier: string): string => createHash('sha256').update(verifier).digest('base64url'); +const generateNonce = (): string => randomBytes(16).toString('base64url'); +type TokenClaims = { + sub?: unknown; + email?: unknown; + name?: unknown; + preferred_username?: unknown; +}; + +@Injectable() +export class AuthService { + private readonly logger = new Logger(AuthService.name); + + constructor( + @Inject(ConfigService) private readonly config: ConfigService, + @Inject(SessionService) private readonly sessions: SessionService, + @Inject(UserService) private readonly users: UserService, + @Inject(OidcService) private readonly oidc: OidcService, + @Inject(LoginStateStore) private readonly loginStates: LoginStateStore, + ) {} + + get mode(): AuthMode { + return this.config.authMode; + } + + async resolveRequestContext(request: FastifyRequest): Promise<{ principal: Principal | null; sessionId: string | null }> { + if (this.mode === 'single_user') { + const principal = await this.getDefaultPrincipal(); + return { principal, sessionId: null }; + } + const sessionId = this.sessions.readSessionIdFromRequest(request); + if (!sessionId) return { principal: null, sessionId: null }; + const principal = await this.buildPrincipalFromSession(sessionId); + return { principal, sessionId: principal ? sessionId : null }; + } + + async resolvePrincipalFromCookieHeader(cookieHeader: string | undefined): Promise { + if (this.mode === 'single_user') { + return this.getDefaultPrincipal(); + } + const sessionId = this.sessions.readSessionIdFromCookieHeader(cookieHeader); + if (!sessionId) return null; + return this.buildPrincipalFromSession(sessionId); + } + + async getAuthStatus(principal: Principal | null): Promise { + if (this.mode === 'single_user') { + const defaultPrincipal = await this.getDefaultPrincipal(); + return { + mode: this.mode, + authenticated: true, + user: defaultPrincipal.user, + }; + } + return { + mode: this.mode, + authenticated: !!principal, + user: principal?.user ?? null, + }; + } + + async initiateLogin(reply: FastifyReply): Promise { + if (this.mode !== 'oidc') { + reply.status(204).send(); + return; + } + const codeVerifier = generateCodeVerifier(); + const codeChallenge = generateCodeChallenge(codeVerifier); + const nonce = generateNonce(); + const state = this.loginStates.create({ codeVerifier, nonce }); + const url = await this.oidc.getAuthorizationUrl({ + state, + nonce, + codeChallenge, + scopes: this.config.oidcScopes, + }); + reply.redirect(url); + } + + async handleOidcCallback(params: { state: string; code: string }, reply: FastifyReply): Promise { + if (this.mode !== 'oidc') { + throw new BadRequestException({ error: 'oidc_disabled' }); + } + const loginState = this.loginStates.consume(params.state); + if (!loginState) { + throw new BadRequestException({ error: 'invalid_state' }); + } + const tokenSet = await this.oidc.handleCallback({ + state: params.state, + code: params.code, + nonce: loginState.nonce, + codeVerifier: loginState.codeVerifier, + }); + const claims = tokenSet.claims() as TokenClaims; + const subject = typeof claims.sub === 'string' ? claims.sub : null; + if (!subject) { + throw new UnauthorizedException({ error: 'missing_subject' }); + } + const email = typeof claims.email === 'string' ? claims.email : null; + const name = + typeof claims.name === 'string' + ? claims.name + : typeof claims.preferred_username === 'string' + ? claims.preferred_username + : email; + const user = await this.users.upsertOidcUser({ + issuer: this.config.oidcIssuerUrl, + subject, + email, + name, + }); + const session = await this.sessions.create(user.id); + reply.header('Set-Cookie', this.sessions.serializeCookie(session.id, session.expiresAt)); + reply.redirect(this.config.oidcPostLoginRedirect); + } + + async logout(reply: FastifyReply, sessionId: string | null): Promise { + if (sessionId) { + await this.sessions.delete(sessionId); + } + reply.header('Set-Cookie', this.sessions.serializeClearCookie()); + reply.status(204).send(); + } + + private async buildPrincipalFromSession(sessionId: string): Promise { + const session = await this.sessions.get(sessionId); + if (!session) return null; + const user = await this.users.getById(session.userId); + if (!user) { + await this.sessions.delete(session.id); + return null; + } + return { + mode: this.mode, + userId: user.id, + sessionId, + user: { + id: user.id, + email: user.email ?? null, + name: user.name ?? null, + }, + }; + } + + private async getDefaultPrincipal(): Promise { + const user = await this.users.ensureDefaultUser(); + return { + mode: 'single_user', + userId: user.id, + sessionId: null, + user: { + id: user.id, + email: user.email ?? null, + name: user.name ?? null, + }, + }; + } +} diff --git a/packages/platform-server/src/auth/auth.types.ts b/packages/platform-server/src/auth/auth.types.ts new file mode 100644 index 000000000..6947594ed --- /dev/null +++ b/packages/platform-server/src/auth/auth.types.ts @@ -0,0 +1,25 @@ +import type { AuthMode } from '../core/services/config.service'; + +export type PrincipalUser = { + id: string; + email: string | null; + name: string | null; +}; + +export type Principal = { + mode: AuthMode; + userId: string; + user: PrincipalUser; + sessionId: string | null; +}; + +export type AuthStatusResponse = { + mode: AuthMode; + authenticated: boolean; + user: PrincipalUser | null; +}; + +export type RequestWithPrincipal = import('fastify').FastifyRequest & { + principal?: Principal | null; + sessionId?: string | null; +}; diff --git a/packages/platform-server/src/auth/authenticated.guard.ts b/packages/platform-server/src/auth/authenticated.guard.ts new file mode 100644 index 000000000..7804d0ab9 --- /dev/null +++ b/packages/platform-server/src/auth/authenticated.guard.ts @@ -0,0 +1,19 @@ +import { CanActivate, ExecutionContext, Inject, Injectable, UnauthorizedException } from '@nestjs/common'; +import { Reflector } from '@nestjs/core'; +import { AuthService } from './auth.service'; +import { IS_PUBLIC_KEY } from './public.decorator'; + +@Injectable() +export class AuthenticatedGuard implements CanActivate { + constructor(@Inject(AuthService) private readonly authService: AuthService, @Inject(Reflector) private readonly reflector: Reflector) {} + + async canActivate(context: ExecutionContext): Promise { + if (this.authService.mode === 'single_user') return true; + if (context.getType() !== 'http') return true; + const isPublic = this.reflector.getAllAndOverride(IS_PUBLIC_KEY, [context.getHandler(), context.getClass()]); + if (isPublic) return true; + const request = context.switchToHttp().getRequest<{ principal?: unknown }>(); + if (request.principal) return true; + throw new UnauthorizedException({ error: 'authentication_required' }); + } +} diff --git a/packages/platform-server/src/auth/login-state.store.ts b/packages/platform-server/src/auth/login-state.store.ts new file mode 100644 index 000000000..efc516b49 --- /dev/null +++ b/packages/platform-server/src/auth/login-state.store.ts @@ -0,0 +1,41 @@ +import { Injectable } from '@nestjs/common'; +import { randomUUID } from 'node:crypto'; + +type LoginRecord = { + codeVerifier: string; + nonce: string; + createdAt: number; +}; + +@Injectable() +export class LoginStateStore { + private readonly ttlMs = 10 * 60 * 1000; + private readonly records = new Map(); + + create(entry: { codeVerifier: string; nonce: string }): string { + this.evictExpired(); + const state = randomUUID(); + this.records.set(state, { ...entry, createdAt: Date.now() }); + return state; + } + + consume(state: string | undefined | null): LoginRecord | null { + if (!state) return null; + const record = this.records.get(state); + this.records.delete(state); + if (!record) return null; + if (Date.now() - record.createdAt > this.ttlMs) { + return null; + } + return record; + } + + private evictExpired(): void { + const now = Date.now(); + for (const [state, record] of this.records) { + if (now - record.createdAt > this.ttlMs) { + this.records.delete(state); + } + } + } +} diff --git a/packages/platform-server/src/auth/oidc.service.ts b/packages/platform-server/src/auth/oidc.service.ts new file mode 100644 index 000000000..23416640b --- /dev/null +++ b/packages/platform-server/src/auth/oidc.service.ts @@ -0,0 +1,85 @@ +import { Inject, Injectable } from '@nestjs/common'; +import * as openidClient from 'openid-client'; +import { ConfigService } from '../core/services/config.service'; + +export type TokenSetResult = { + claims(): Record; +}; + +type AuthorizationClient = { + authorizationUrl(params: { + state: string; + nonce: string; + scope: string; + redirect_uri: string; + code_challenge: string; + code_challenge_method: 'S256'; + }): string; + callback( + redirectUri: string, + parameters: { code: string; state: string }, + checks: { state: string; nonce: string; code_verifier: string }, + ): Promise; +}; + +type IssuerApi = { + discover(issuerUrl: string): Promise<{ Client: new (metadata: ClientMetadata) => AuthorizationClient }>; +}; + +type ClientMetadata = { + client_id: string; + client_secret?: string; + redirect_uris: [string]; + response_types: [string]; + token_endpoint_auth_method: 'client_secret_basic' | 'none'; +}; + +const issuerApi = (openidClient as unknown as { Issuer: IssuerApi }).Issuer; + +@Injectable() +export class OidcService { + private clientPromise: Promise | null = null; + + constructor(@Inject(ConfigService) private readonly config: ConfigService) {} + + async getAuthorizationUrl(params: { state: string; nonce: string; codeChallenge: string; scopes: string[] }): Promise { + const client = await this.ensureClient(); + return client.authorizationUrl({ + state: params.state, + nonce: params.nonce, + scope: params.scopes.join(' '), + redirect_uri: this.config.oidcRedirectUri, + code_challenge: params.codeChallenge, + code_challenge_method: 'S256', + }); + } + + async handleCallback(params: { state: string; code: string; nonce: string; codeVerifier: string }): Promise { + const client = await this.ensureClient(); + const checks = { state: params.state, nonce: params.nonce, code_verifier: params.codeVerifier }; + return client.callback( + this.config.oidcRedirectUri, + { code: params.code, state: params.state }, + checks, + ) as Promise; + } + + private async ensureClient(): Promise { + if (!this.clientPromise) { + this.clientPromise = this.createClient(); + } + return this.clientPromise; + } + + private async createClient(): Promise { + const issuer = await issuerApi.discover(this.config.oidcIssuerUrl); + const metadata: ClientMetadata = { + client_id: this.config.oidcClientId, + client_secret: this.config.oidcClientSecret || undefined, + redirect_uris: [this.config.oidcRedirectUri], + response_types: ['code'], + token_endpoint_auth_method: this.config.oidcClientSecret ? 'client_secret_basic' : 'none', + }; + return new issuer.Client(metadata); + } +} diff --git a/packages/platform-server/src/auth/principal.decorator.ts b/packages/platform-server/src/auth/principal.decorator.ts new file mode 100644 index 000000000..d4078a070 --- /dev/null +++ b/packages/platform-server/src/auth/principal.decorator.ts @@ -0,0 +1,7 @@ +import { createParamDecorator, ExecutionContext } from '@nestjs/common'; +import type { Principal } from './auth.types'; + +export const CurrentPrincipal = createParamDecorator((_data: unknown, ctx: ExecutionContext): Principal | null => { + const request = ctx.switchToHttp().getRequest<{ principal?: Principal | null }>(); + return request.principal ?? null; +}); diff --git a/packages/platform-server/src/auth/principal.guard.ts b/packages/platform-server/src/auth/principal.guard.ts new file mode 100644 index 000000000..73ae92e17 --- /dev/null +++ b/packages/platform-server/src/auth/principal.guard.ts @@ -0,0 +1,17 @@ +import { CanActivate, ExecutionContext, Inject, Injectable } from '@nestjs/common'; +import type { RequestWithPrincipal } from './auth.types'; +import { AuthService } from './auth.service'; + +@Injectable() +export class PrincipalGuard implements CanActivate { + constructor(@Inject(AuthService) private readonly authService: AuthService) {} + + async canActivate(context: ExecutionContext): Promise { + if (context.getType() !== 'http') return true; + const request = context.switchToHttp().getRequest(); + const { principal, sessionId } = await this.authService.resolveRequestContext(request); + request.principal = principal; + request.sessionId = sessionId; + return true; + } +} diff --git a/packages/platform-server/src/auth/public.decorator.ts b/packages/platform-server/src/auth/public.decorator.ts new file mode 100644 index 000000000..504bbbf06 --- /dev/null +++ b/packages/platform-server/src/auth/public.decorator.ts @@ -0,0 +1,4 @@ +import { SetMetadata } from '@nestjs/common'; + +export const IS_PUBLIC_KEY = 'auth:isPublic'; +export const Public = () => SetMetadata(IS_PUBLIC_KEY, true); diff --git a/packages/platform-server/src/auth/session.service.ts b/packages/platform-server/src/auth/session.service.ts new file mode 100644 index 000000000..836f076d5 --- /dev/null +++ b/packages/platform-server/src/auth/session.service.ts @@ -0,0 +1,121 @@ +import { Inject, Injectable, Logger } from '@nestjs/common'; +import type { FastifyRequest } from 'fastify'; +import { createHmac, randomUUID, timingSafeEqual } from 'node:crypto'; +import { parse as parseCookie, serialize as serializeCookie } from 'cookie'; +import type { SerializeOptions } from 'cookie'; +import { PrismaService } from '../core/services/prisma.service'; +import { ConfigService } from '../core/services/config.service'; + +const SESSION_TTL_MS = 30 * 24 * 60 * 60 * 1000; // 30 days +const SESSION_COOKIE = 'agyn_session'; + +export type SessionRecord = { + id: string; + userId: string; + expiresAt: Date; +}; + +@Injectable() +export class SessionService { + private readonly logger = new Logger(SessionService.name); + private readonly cookieOptions: SerializeOptions; + + constructor(@Inject(PrismaService) private readonly prisma: PrismaService, @Inject(ConfigService) private readonly config: ConfigService) { + this.cookieOptions = { + path: '/', + httpOnly: true, + sameSite: 'lax', + secure: this.config.isProduction, + }; + } + + async create(userId: string): Promise { + const id = randomUUID(); + const expiresAt = new Date(Date.now() + SESSION_TTL_MS); + const session = await this.prisma.getClient().session.create({ + data: { + id, + userId, + expiresAt, + }, + select: { id: true, userId: true, expiresAt: true }, + }); + return session; + } + + async get(sessionId: string): Promise { + const session = await this.prisma.getClient().session.findUnique({ + where: { id: sessionId }, + select: { id: true, userId: true, expiresAt: true }, + }); + if (!session) return null; + if (session.expiresAt.getTime() <= Date.now()) { + await this.safeDelete(session.id); + return null; + } + return session; + } + + async delete(sessionId: string): Promise { + await this.safeDelete(sessionId); + } + + readSessionIdFromRequest(request: FastifyRequest): string | null { + return this.readSessionIdFromCookieHeader(request.headers.cookie); + } + + readSessionIdFromCookieHeader(cookieHeader: string | undefined): string | null { + if (!cookieHeader) return null; + const cookies = parseCookie(cookieHeader); + const token = cookies[SESSION_COOKIE]; + return this.decodeCookieValue(token); + } + + serializeCookie(sessionId: string, expiresAt: Date): string { + const signed = this.encodeCookieValue(sessionId); + return serializeCookie(SESSION_COOKIE, signed, { ...this.cookieOptions, expires: expiresAt }); + } + + serializeClearCookie(): string { + return serializeCookie(SESSION_COOKIE, '', { + ...this.cookieOptions, + expires: new Date(0), + }); + } + + private encodeCookieValue(sessionId: string): string { + const signature = this.sign(sessionId); + return `${sessionId}.${signature}`; + } + + private decodeCookieValue(value: string | undefined): string | null { + if (!value) return null; + const [sessionId, signature] = value.split('.'); + if (!sessionId || !signature) return null; + const expected = this.sign(sessionId); + if (!this.safeEqual(signature, expected)) { + this.logger.warn('Session cookie signature mismatch'); + return null; + } + return sessionId; + } + + private sign(value: string): string { + return createHmac('sha256', this.config.sessionSecret).update(value).digest('hex'); + } + + private safeEqual(a: string, b: string): boolean { + const bufA = Buffer.from(a); + const bufB = Buffer.from(b); + if (bufA.length !== bufB.length) return false; + return timingSafeEqual(bufA, bufB); + } + + private async safeDelete(sessionId: string): Promise { + try { + await this.prisma.getClient().session.delete({ where: { id: sessionId } }); + } catch (error) { + this.logger.debug(`Session delete ignored: ${(error as Error).message}`); + } + } +} diff --git a/packages/platform-server/src/auth/user.service.ts b/packages/platform-server/src/auth/user.service.ts new file mode 100644 index 000000000..e80e3e7d6 --- /dev/null +++ b/packages/platform-server/src/auth/user.service.ts @@ -0,0 +1,66 @@ +import { Inject, Injectable } from '@nestjs/common'; +import type { Prisma, User } from '@prisma/client'; +import { PrismaService } from '../core/services/prisma.service'; + +const DEFAULT_USER_ID = '00000000-0000-0000-0000-000000000001'; +type UserSummary = Pick; + +@Injectable() +export class UserService { + private readonly includeSelect = { id: true, email: true, name: true } satisfies Prisma.UserSelect; + private defaultUserPromise: Promise | null = null; + + constructor(@Inject(PrismaService) private readonly prisma: PrismaService) {} + + get defaultUserId(): string { + return DEFAULT_USER_ID; + } + + async ensureDefaultUser(): Promise { + if (!this.defaultUserPromise) { + this.defaultUserPromise = this.prisma.getClient().user.upsert({ + where: { id: DEFAULT_USER_ID }, + update: { updatedAt: new Date() }, + create: { + id: DEFAULT_USER_ID, + email: 'default@local', + name: 'Default User', + createdAt: new Date(), + updatedAt: new Date(), + }, + select: this.includeSelect, + }); + } + return this.defaultUserPromise; + } + + async getById(id: string): Promise { + const user = await this.prisma.getClient().user.findUnique({ where: { id }, select: this.includeSelect }); + return user; + } + + async upsertOidcUser(params: { issuer: string; subject: string; email?: string | null; name?: string | null }): Promise { + const now = new Date(); + const identityKey: Prisma.UserWhereUniqueInput = { + oidcIssuer_oidcSubject: { oidcIssuer: params.issuer, oidcSubject: params.subject }, + }; + const user = await this.prisma.getClient().user.upsert({ + where: identityKey, + update: { + email: params.email ?? undefined, + name: params.name ?? undefined, + updatedAt: now, + }, + create: { + email: params.email ?? null, + name: params.name ?? null, + oidcIssuer: params.issuer, + oidcSubject: params.subject, + createdAt: now, + updatedAt: now, + }, + select: this.includeSelect, + }); + return user; + } +} diff --git a/packages/platform-server/src/bootstrap/app.module.ts b/packages/platform-server/src/bootstrap/app.module.ts index 5a7f243b4..286b80535 100644 --- a/packages/platform-server/src/bootstrap/app.module.ts +++ b/packages/platform-server/src/bootstrap/app.module.ts @@ -12,6 +12,7 @@ import { LLMModule } from '../llm/llm.module'; import { LLMProvisioner } from '../llm/provisioners/llm.provisioner'; import { OnboardingModule } from '../onboarding/onboarding.module'; import { UserProfileModule } from '../user-profile/user-profile.module'; +import { AuthModule } from '../auth/auth.module'; type PinoLoggerModule = { forRoot: (options: { @@ -55,6 +56,7 @@ const createLoggerModule = (): DynamicModule => { imports: [ createLoggerModule(), CoreModule, + AuthModule, EventsModule, InfraModule, GraphApiModule, diff --git a/packages/platform-server/src/core/services/config.service.ts b/packages/platform-server/src/core/services/config.service.ts index 903d63d64..5004eb166 100644 --- a/packages/platform-server/src/core/services/config.service.ts +++ b/packages/platform-server/src/core/services/config.service.ts @@ -2,8 +2,10 @@ import { Injectable } from '@nestjs/common'; import * as dotenv from 'dotenv'; import { z } from 'zod'; dotenv.config(); +export type AuthMode = 'single_user' | 'oidc'; -export const configSchema = z.object({ +export const configSchema = z + .object({ // GitHub settings are optional to allow dev boot without GitHub githubAppId: z.string().min(1).optional(), githubAppPrivateKey: z.string().min(1).optional(), @@ -159,14 +161,65 @@ export const configSchema = z.object({ // CORS origins (comma-separated in env; parsed to string[]) corsOrigins: z .string() - .default("") + .default('') .transform((s) => s - .split(",") + .split(',') .map((x) => x.trim()) .filter((x) => !!x), ), -}); + authMode: z.enum(['single_user', 'oidc']).default('single_user'), + sessionSecret: z + .string() + .default('dev-session-secret-change-me-0123456789abcdef0123456789abcdef') + .transform((value) => value.trim()) + .refine((value) => value.length >= 32, 'SESSION_SECRET must be at least 32 characters'), + oidcIssuerUrl: z.string().url().optional(), + oidcClientId: z.string().optional(), + oidcClientSecret: z.string().optional(), + oidcRedirectUri: z.string().url().optional(), + oidcScopes: z + .string() + .default('openid profile email') + .transform((value) => + value + .split(/[\s,]+/) + .map((scope) => scope.trim()) + .filter((scope) => scope.length > 0), + ), + oidcPostLoginRedirect: z + .string() + .default('/') + .transform((value) => { + const trimmed = value.trim(); + return trimmed.length === 0 ? '/' : trimmed; + }), +}) + .superRefine((value, ctx) => { + if (value.authMode === 'oidc') { + if (!value.oidcIssuerUrl) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: 'OIDC_ISSUER_URL is required when AUTH_MODE=oidc', + path: ['oidcIssuerUrl'], + }); + } + if (!value.oidcClientId) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: 'OIDC_CLIENT_ID is required when AUTH_MODE=oidc', + path: ['oidcClientId'], + }); + } + if (!value.oidcRedirectUri) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: 'OIDC_REDIRECT_URI is required when AUTH_MODE=oidc', + path: ['oidcRedirectUri'], + }); + } + } + }); export type Config = z.infer; @@ -370,6 +423,45 @@ export class ConfigService implements Config { get nixRepoAllowlist(): string[] { return this.params.nixRepoAllowlist ?? []; } + get authMode(): AuthMode { + return this.params.authMode; + } + get sessionSecret(): string { + return this.params.sessionSecret; + } + get oidcIssuerUrl(): string { + if (this.params.authMode !== 'oidc' || !this.params.oidcIssuerUrl) { + throw new Error('OIDC issuer not configured'); + } + return this.params.oidcIssuerUrl; + } + get oidcClientId(): string { + if (this.params.authMode !== 'oidc' || !this.params.oidcClientId) { + throw new Error('OIDC client id not configured'); + } + return this.params.oidcClientId; + } + get oidcClientSecret(): string | undefined { + return this.params.oidcClientSecret ?? undefined; + } + get oidcRedirectUri(): string { + if (this.params.authMode !== 'oidc' || !this.params.oidcRedirectUri) { + throw new Error('OIDC redirect URI not configured'); + } + return this.params.oidcRedirectUri; + } + get oidcScopes(): string[] { + return this.params.oidcScopes; + } + get oidcPostLoginRedirect(): string { + return this.params.oidcPostLoginRedirect; + } + get isProduction(): boolean { + const envName = (process.env.NODE_ENV ?? '').toLowerCase(); + if (envName === 'production') return true; + const agentsEnv = (process.env.AGENTS_ENV ?? '').toLowerCase(); + return agentsEnv === 'production'; + } // No global messaging adapter config in Slack-only v1 @@ -420,6 +512,14 @@ export class ConfigService implements Config { ncpsAuthToken: process.env.NCPS_AUTH_TOKEN, agentsDatabaseUrl: process.env.AGENTS_DATABASE_URL, corsOrigins: process.env.CORS_ORIGINS, + authMode: process.env.AUTH_MODE, + sessionSecret: process.env.SESSION_SECRET, + oidcIssuerUrl: process.env.OIDC_ISSUER_URL, + oidcClientId: process.env.OIDC_CLIENT_ID, + oidcClientSecret: process.env.OIDC_CLIENT_SECRET, + oidcRedirectUri: process.env.OIDC_REDIRECT_URI, + oidcScopes: process.env.OIDC_SCOPES, + oidcPostLoginRedirect: process.env.OIDC_POST_LOGIN_REDIRECT, }); const config = new ConfigService().init(parsed); ConfigService.register(config); diff --git a/packages/platform-server/src/core/services/startupRecovery.service.ts b/packages/platform-server/src/core/services/startupRecovery.service.ts index 9f8d4b993..378aa5997 100644 --- a/packages/platform-server/src/core/services/startupRecovery.service.ts +++ b/packages/platform-server/src/core/services/startupRecovery.service.ts @@ -12,6 +12,7 @@ type RecoveredRun = { status: RunStatus; createdAt: Date; updatedAt: Date; + ownerUserId: string; }; type RecoveredReminder = { @@ -132,9 +133,16 @@ export class StartupRecoveryService implements OnApplicationBootstrap { const updated = await runDelegate.findMany({ where: { id: { in: ids }, status: RunStatusEnum.terminated }, - select: { id: true, threadId: true, status: true, createdAt: true, updatedAt: true }, + select: { id: true, threadId: true, status: true, createdAt: true, updatedAt: true, thread: { select: { ownerUserId: true } } }, }); - return updated.map((run) => ({ ...run })); + return updated.map((run) => ({ + id: run.id, + threadId: run.threadId, + status: run.status, + createdAt: run.createdAt, + updatedAt: run.updatedAt, + ownerUserId: run.thread?.ownerUserId ?? '', + })); } private async completePendingReminders(tx: TransactionClient): Promise { @@ -211,9 +219,14 @@ export class StartupRecoveryService implements OnApplicationBootstrap { for (const run of runs) { metricThreads.add(run.threadId); + if (!run.ownerUserId) { + this.logger.warn('Skipping run_status_changed emission due to missing owner', { runId: run.id, threadId: run.threadId }); + continue; + } try { this.eventsBus.emitRunStatusChanged({ threadId: run.threadId, + ownerUserId: run.ownerUserId, run: { id: run.id, status: runStatus, diff --git a/packages/platform-server/src/events/events-bus.service.ts b/packages/platform-server/src/events/events-bus.service.ts index 30c6f6fc9..524d423e6 100644 --- a/packages/platform-server/src/events/events-bus.service.ts +++ b/packages/platform-server/src/events/events-bus.service.ts @@ -37,6 +37,7 @@ export type ThreadBroadcast = { summary: string | null; status: ThreadStatus; createdAt: Date; + ownerUserId: string; parentId?: string | null; channelNodeId?: string | null; assignedAgentNodeId?: string | null; @@ -51,6 +52,12 @@ export type MessageBroadcast = { runId?: string; }; +export type MessageCreatedEvent = { + threadId: string; + ownerUserId: string; + message: MessageBroadcast; +}; + export type ThreadMetricsEvent = { threadId: string; }; @@ -61,6 +68,7 @@ export type ThreadMetricsAncestorsEvent = { export type RunStatusBroadcast = { threadId: string; + ownerUserId: string; run: { id: string; status: RunStatus; @@ -77,7 +85,7 @@ type EventsBusEvents = { node_state: [NodeStateBusEvent]; thread_created: [ThreadBroadcast]; thread_updated: [ThreadBroadcast]; - message_created: [{ threadId: string; message: MessageBroadcast }]; + message_created: [MessageCreatedEvent]; run_status_changed: [RunStatusBroadcast]; thread_metrics: [ThreadMetricsEvent]; thread_metrics_ancestors: [ThreadMetricsAncestorsEvent]; @@ -174,14 +182,14 @@ export class EventsBusService implements OnModuleDestroy { this.emitter.emit('thread_updated', thread); } - subscribeToMessageCreated(listener: (payload: { threadId: string; message: MessageBroadcast }) => void): () => void { + subscribeToMessageCreated(listener: (payload: MessageCreatedEvent) => void): () => void { this.emitter.on('message_created', listener); return () => { this.emitter.off('message_created', listener); }; } - emitMessageCreated(payload: { threadId: string; message: MessageBroadcast }): void { + emitMessageCreated(payload: MessageCreatedEvent): void { this.emitter.emit('message_created', payload); } diff --git a/packages/platform-server/src/gateway/graph.socket.gateway.ts b/packages/platform-server/src/gateway/graph.socket.gateway.ts index 8eb04c8b8..725c89ed9 100644 --- a/packages/platform-server/src/gateway/graph.socket.gateway.ts +++ b/packages/platform-server/src/gateway/graph.socket.gateway.ts @@ -3,10 +3,10 @@ import type { IncomingHttpHeaders, Server as HTTPServer } from 'http'; import { Server as SocketIOServer, type ServerOptions, type Socket } from 'socket.io'; import { z } from 'zod'; import { LiveGraphRuntime } from '../graph-core/liveGraph.manager'; -import type { ThreadStatus, MessageKind, RunStatus } from '@prisma/client'; +import type { MessageKind } from '@prisma/client'; import { EventsBusService, - type MessageBroadcast, + type MessageCreatedEvent, type NodeStateBusEvent, type ReminderCountEvent as ReminderCountBusEvent, type RunEventBroadcast, @@ -19,6 +19,8 @@ import { import type { ToolOutputChunkPayload, ToolOutputTerminalPayload } from '../events/run-events.service'; import { ThreadsMetricsService } from '../agents/threads.metrics.service'; import { PrismaService } from '../core/services/prisma.service'; +import { ConfigService } from '../core/services/config.service'; +import { AuthService } from '../auth/auth.service'; // Strict outbound event payloads export const NodeStatusEventSchema = z @@ -101,6 +103,15 @@ function toDate(value: string): Date | null { return Number.isNaN(ts.getTime()) ? null : ts; } +const RoomSchema = z.union([ + z.literal('threads'), + z.literal('graph'), + z.string().regex(/^thread:[0-9a-z-]{1,64}$/i), + z.string().regex(/^run:[0-9a-z-]{1,64}$/i), + z.string().regex(/^node:[0-9a-z-]{1,64}$/i), +]); +const SubscribeSchema = z.object({ rooms: z.array(RoomSchema).optional(), room: RoomSchema.optional() }).strict(); + @Injectable({ scope: Scope.DEFAULT }) export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { private readonly logger = new Logger(GraphSocketGateway.name); @@ -110,13 +121,19 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { private metricsTimer: NodeJS.Timeout | null = null; private readonly COALESCE_MS = 100; private readonly cleanup: Array<() => void> = []; + private readonly allowedOrigins: string[]; + private readonly threadOwnerCache = new Map(); constructor( @Inject(LiveGraphRuntime) private readonly runtime: LiveGraphRuntime, @Inject(ThreadsMetricsService) private readonly metrics: ThreadsMetricsService, @Inject(PrismaService) private readonly prismaService: PrismaService, @Inject(EventsBusService) private readonly eventsBus: EventsBusService, - ) {} + @Inject(ConfigService) private readonly config: ConfigService, + @Inject(AuthService) private readonly authService: AuthService, + ) { + this.allowedOrigins = this.config.corsOrigins ?? []; + } onModuleInit(): void { this.cleanup.push(this.eventsBus.subscribeToRunEvents(this.handleRunEvent)); @@ -151,55 +168,26 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { const options: Partial = { path: '/socket.io', transports: ['websocket'] as ServerOptions['transports'], - cors: { origin: '*' }, - allowRequest: (_req, callback) => { - callback(null, true); + cors: { + origin: this.allowedOrigins.length ? this.allowedOrigins : true, + credentials: true, }, - }; - this.io = new SocketIOServer(server, options); - this.io.on('connection', (socket: Socket) => { - // Room subscription - const RoomSchema = z.union([ - z.literal('threads'), - z.literal('graph'), - z.string().regex(/^thread:[0-9a-z-]{1,64}$/i), - z.string().regex(/^run:[0-9a-z-]{1,64}$/i), - z.string().regex(/^node:[0-9a-z-]{1,64}$/i), - ]); - const SubscribeSchema = z - .object({ rooms: z.array(RoomSchema).optional(), room: RoomSchema.optional() }) - .strict(); - socket.on('subscribe', (payload: unknown, ack?: (response: unknown) => void) => { - const parsed = SubscribeSchema.safeParse(payload); - if (!parsed.success) { - const details = parsed.error.issues.map((issue) => ({ - path: issue.path, - message: issue.message, - code: issue.code, - })); - this.logger.warn( - `GraphSocketGateway: subscribe invalid${this.formatContext({ socketId: socket.id, issues: details })}`, - ); - if (typeof ack === 'function') { - ack({ ok: false, error: 'invalid_payload', issues: details }); - } + allowRequest: (req, callback) => { + if (this.allowedOrigins.length === 0) { + callback(null, true); return; } - const p = parsed.data; - const rooms: string[] = p.rooms ?? (p.room ? [p.room] : []); - for (const r of rooms) if (r.length > 0) socket.join(r); - if (typeof ack === 'function') { - ack({ ok: true, rooms }); + const originHeader = typeof req.headers.origin === 'string' ? req.headers.origin : undefined; + if (!originHeader || this.allowedOrigins.includes(originHeader)) { + callback(null, true); + return; } - }); - socket.on('error', (e: unknown) => { - this.logger.warn( - `GraphSocketGateway: socket error${this.formatContext({ - socketId: socket.id, - error: this.toSafeError(e), - })}`, - ); - }); + callback('forbidden_origin', false); + }, + }; + this.io = new SocketIOServer(server, options); + this.io.on('connection', (socket: Socket) => { + void this.initializeSocket(socket); }); this.initialized = true; // Wire runtime status events to socket broadcast @@ -207,6 +195,88 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { return this; } + private async initializeSocket(socket: Socket): Promise { + try { + const principal = await this.authService.resolvePrincipalFromCookieHeader( + socket.request.headers.cookie, + ); + if (!principal) { + this.logger.warn( + `GraphSocketGateway: unauthorized connection${this.formatContext({ socketId: socket.id })}`, + ); + socket.emit('error', { error: 'unauthorized' }); + socket.disconnect(true); + return; + } + this.setupSocketHandlers(socket, principal.userId); + } catch (error) { + this.logger.warn( + `GraphSocketGateway: connection setup failed${this.formatContext({ + socketId: socket.id, + error: this.toSafeError(error), + })}`, + ); + socket.disconnect(true); + } + } + + private setupSocketHandlers(socket: Socket, userId: string): void { + socket.on('subscribe', (payload: unknown, ack?: (response: unknown) => void) => { + const parsed = SubscribeSchema.safeParse(payload); + if (!parsed.success) { + const details = parsed.error.issues.map((issue) => ({ + path: issue.path, + message: issue.message, + code: issue.code, + })); + this.logger.warn( + `GraphSocketGateway: subscribe invalid${this.formatContext({ socketId: socket.id, issues: details })}`, + ); + if (typeof ack === 'function') { + ack({ ok: false, error: 'invalid_payload', issues: details }); + } + return; + } + const request = parsed.data; + const requestedRooms: string[] = request.rooms ?? (request.room ? [request.room] : []); + const joined: string[] = []; + for (const room of requestedRooms) { + if (!room) continue; + const resolved = this.resolveRoomForUser(room, userId); + if (!resolved) continue; + socket.join(resolved); + joined.push(room); + } + if (typeof ack === 'function') { + ack({ ok: true, rooms: joined }); + } + }); + socket.on('error', (e: unknown) => { + this.logger.warn( + `GraphSocketGateway: socket error${this.formatContext({ + socketId: socket.id, + error: this.toSafeError(e), + })}`, + ); + }); + } + + private resolveRoomForUser(room: string, userId: string): string | null { + if (this.isThreadScopedRoom(room)) { + if (!userId) return null; + return this.formatUserRoom(userId, room); + } + return room; + } + + private isThreadScopedRoom(room: string): boolean { + return room === 'threads' || room.startsWith('thread:') || room.startsWith('run:'); + } + + private formatUserRoom(userId: string, room: string): string { + return `user:${userId}:${room}`; + } + private readonly handleRunEvent = (payload: RunEventBusPayload): void => { const event = payload.event; if (!event) { @@ -218,14 +288,12 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { ); return; } - try { - const broadcast: RunEventBroadcast = { - runId: event.runId, - mutation: payload.mutation, - event, - }; - this.emitRunEvent(event.runId, event.threadId, broadcast); - } catch (err) { + const broadcast: RunEventBroadcast = { + runId: event.runId, + mutation: payload.mutation, + event, + }; + void this.emitRunEvent(event.runId, event.threadId, broadcast).catch((err) => { this.logger.warn( `GraphSocketGateway failed to emit run event${this.formatContext({ eventId: payload.eventId, @@ -233,7 +301,7 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { error: this.toSafeError(err), })}`, ); - } + }); }; private readonly handleToolOutputChunk = (payload: ToolOutputChunkPayload): void => { @@ -247,8 +315,8 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { ); return; } - try { - this.emitToolOutputChunk({ + void this + .emitToolOutputChunk({ runId: payload.runId, threadId: payload.threadId, eventId: payload.eventId, @@ -257,15 +325,15 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { source: payload.source, ts, data: payload.data, + }) + .catch((err) => { + this.logger.warn( + `GraphSocketGateway failed to emit tool_output_chunk${this.formatContext({ + eventId: payload.eventId, + error: this.toSafeError(err), + })}`, + ); }); - } catch (err) { - this.logger.warn( - `GraphSocketGateway failed to emit tool_output_chunk${this.formatContext({ - eventId: payload.eventId, - error: this.toSafeError(err), - })}`, - ); - } }; private readonly handleToolOutputTerminal = (payload: ToolOutputTerminalPayload): void => { @@ -279,8 +347,8 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { ); return; } - try { - this.emitToolOutputTerminal({ + void this + .emitToolOutputTerminal({ runId: payload.runId, threadId: payload.threadId, eventId: payload.eventId, @@ -293,15 +361,15 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { savedPath: payload.savedPath ?? undefined, message: payload.message ?? undefined, ts, + }) + .catch((err) => { + this.logger.warn( + `GraphSocketGateway failed to emit tool_output_terminal${this.formatContext({ + eventId: payload.eventId, + error: this.toSafeError(err), + })}`, + ); }); - } catch (err) { - this.logger.warn( - `GraphSocketGateway failed to emit tool_output_terminal${this.formatContext({ - eventId: payload.eventId, - error: this.toSafeError(err), - })}`, - ); - } }; private readonly handleReminderCount = (payload: ReminderCountBusEvent): void => { @@ -384,7 +452,7 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { } }; - private readonly handleMessageCreated = (payload: { threadId: string; message: MessageBroadcast }): void => { + private readonly handleMessageCreated = (payload: MessageCreatedEvent): void => { try { this.logger.log( `new message${this.formatContext({ @@ -394,7 +462,7 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { runId: payload.message.runId ?? null, })}`, ); - this.emitMessageCreated(payload.threadId, payload.message); + this.emitMessageCreated(payload.threadId, payload.ownerUserId, payload.message); } catch (err) { this.logger.warn( `GraphSocketGateway failed to emit message_created${this.formatContext({ @@ -408,7 +476,7 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { private readonly handleRunStatusChanged = (payload: RunStatusBroadcast): void => { try { - this.emitRunStatusChanged(payload.threadId, payload.run); + this.emitRunStatusChanged(payload); } catch (err) { this.logger.warn( `GraphSocketGateway failed to emit run_status_changed${this.formatContext({ @@ -508,51 +576,43 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { } // Threads realtime events - emitThreadCreated(thread: { - id: string; - alias: string; - summary: string | null; - status: ThreadStatus; - createdAt: Date; - parentId?: string | null; - channelNodeId?: string | null; - }) { + emitThreadCreated(thread: ThreadBroadcast) { + this.rememberThreadOwner(thread.id, thread.ownerUserId); const payload = { thread: { ...thread, createdAt: thread.createdAt.toISOString() } }; - this.emitToRooms(['threads'], 'thread_created', payload); + this.emitToUserRooms(thread.ownerUserId, ['threads'], 'thread_created', payload); } - emitThreadUpdated(thread: { - id: string; - alias: string; - summary: string | null; - status: ThreadStatus; - createdAt: Date; - parentId?: string | null; - channelNodeId?: string | null; - }) { + emitThreadUpdated(thread: ThreadBroadcast) { + this.rememberThreadOwner(thread.id, thread.ownerUserId); const payload = { thread: { ...thread, createdAt: thread.createdAt.toISOString() } }; - this.emitToRooms(['threads'], 'thread_updated', payload); + this.emitToUserRooms(thread.ownerUserId, ['threads', `thread:${thread.id}`], 'thread_updated', payload); } - emitMessageCreated(threadId: string, message: { id: string; kind: MessageKind; text: string | null; source: import('type-fest').JsonValue | unknown; createdAt: Date; runId?: string }) { + emitMessageCreated( + threadId: string, + ownerUserId: string, + message: { id: string; kind: MessageKind; text: string | null; source: import('type-fest').JsonValue | unknown; createdAt: Date; runId?: string }, + ) { const payload = { threadId, message: { ...message, createdAt: message.createdAt.toISOString() } }; - this.emitToRooms([`thread:${threadId}`], 'message_created', payload); + this.rememberThreadOwner(threadId, ownerUserId); + this.emitToUserRooms(ownerUserId, [`thread:${threadId}`], 'message_created', payload); } - emitRunStatusChanged(threadId: string, run: { id: string; status: RunStatus; createdAt: Date; updatedAt: Date }) { - const payload = { - threadId, + emitRunStatusChanged(payload: RunStatusBroadcast) { + const eventPayload = { + threadId: payload.threadId, run: { - ...run, - threadId, - createdAt: run.createdAt.toISOString(), - updatedAt: run.updatedAt.toISOString(), + ...payload.run, + threadId: payload.threadId, + createdAt: payload.run.createdAt.toISOString(), + updatedAt: payload.run.updatedAt.toISOString(), }, }; - this.emitToRooms([`thread:${threadId}`, `run:${run.id}`], 'run_status_changed', payload); + this.rememberThreadOwner(payload.threadId, payload.ownerUserId); + this.emitToUserRooms(payload.ownerUserId, [`thread:${payload.threadId}`, `run:${payload.run.id}`], 'run_status_changed', eventPayload); } - emitRunEvent(runId: string, threadId: string, payload: RunEventBroadcast) { + async emitRunEvent(runId: string, threadId: string, payload: RunEventBroadcast) { const eventName = payload.mutation === 'update' ? 'run_event_updated' : 'run_event_appended'; - this.emitToRooms([`run:${runId}`, `thread:${threadId}`], eventName, payload); + await this.emitThreadRooms(threadId, [`run:${runId}`, `thread:${threadId}`], eventName, payload); } - emitToolOutputChunk(payload: { + async emitToolOutputChunk(payload: { runId: string; threadId: string; eventId: string; @@ -579,9 +639,9 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { ); return; } - this.emitToRooms([`run:${eventPayload.runId}`, `thread:${eventPayload.threadId}`], 'tool_output_chunk', eventPayload); + await this.emitThreadRooms(eventPayload.threadId, [`run:${eventPayload.runId}`, `thread:${eventPayload.threadId}`], 'tool_output_chunk', eventPayload); } - emitToolOutputTerminal(payload: { + async emitToolOutputTerminal(payload: { runId: string; threadId: string; eventId: string; @@ -616,7 +676,7 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { ); return; } - this.emitToRooms([`run:${eventPayload.runId}`, `thread:${eventPayload.threadId}`], 'tool_output_terminal', eventPayload); + await this.emitThreadRooms(eventPayload.threadId, [`run:${eventPayload.runId}`, `thread:${eventPayload.threadId}`], 'tool_output_terminal', eventPayload); } private flushMetricsQueue = async () => { // De-duplicate pending thread IDs per flush (preserve insertion order) @@ -629,10 +689,13 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { for (const id of ids) { const m = map[id]; if (!m) continue; + const ownerUserId = await this.getThreadOwnerId(id); + if (!ownerUserId) continue; const activityPayload = { threadId: id, activity: m.activity }; - this.emitToRooms(['threads', `thread:${id}`], 'thread_activity_changed', activityPayload); const remindersPayload = { threadId: id, remindersCount: m.remindersCount }; - this.emitToRooms(['threads', `thread:${id}`], 'thread_reminders_count', remindersPayload); + const rooms = ['threads', `thread:${id}`]; + this.emitToUserRooms(ownerUserId, rooms, 'thread_activity_changed', activityPayload); + this.emitToUserRooms(ownerUserId, rooms, 'thread_reminders_count', remindersPayload); } } catch (e) { this.logger.error(`flushMetricsQueue error${this.formatContext({ error: this.toSafeError(e) })}`); @@ -702,6 +765,41 @@ export class GraphSocketGateway implements OnModuleInit, OnModuleDestroy { } } + private emitToUserRooms(userId: string, rooms: string[], event: string, payload: unknown): void { + if (!userId) return; + const resolved = rooms.map((room) => this.formatUserRoom(userId, room)); + this.emitToRooms(resolved, event, payload); + } + + private async emitThreadRooms( + threadId: string, + rooms: string[], + event: string, + payload: unknown, + ): Promise { + const ownerUserId = await this.getThreadOwnerId(threadId); + if (!ownerUserId) return; + this.emitToUserRooms(ownerUserId, rooms, event, payload); + } + + private async getThreadOwnerId(threadId: string): Promise { + if (!threadId) return null; + const cached = this.threadOwnerCache.get(threadId); + if (cached) return cached; + const prisma = this.prismaService.getClient(); + const repository = prisma?.thread; + if (!repository?.findUnique) return null; + const row = await repository.findUnique({ where: { id: threadId }, select: { ownerUserId: true } }); + if (!row?.ownerUserId) return null; + this.threadOwnerCache.set(threadId, row.ownerUserId); + return row.ownerUserId; + } + + private rememberThreadOwner(threadId: string, ownerUserId: string | null | undefined): void { + if (!threadId || !ownerUserId) return; + this.threadOwnerCache.set(threadId, ownerUserId); + } + private formatContext(context: Record): string { return ` ${JSON.stringify(context)}`; } diff --git a/packages/platform-server/src/graph/controllers/memory.controller.ts b/packages/platform-server/src/graph/controllers/memory.controller.ts index a814bd97b..63233351e 100644 --- a/packages/platform-server/src/graph/controllers/memory.controller.ts +++ b/packages/platform-server/src/graph/controllers/memory.controller.ts @@ -1,7 +1,24 @@ -import { Body, Controller, Delete, Get, HttpCode, HttpException, HttpStatus, Inject, Param, Post, Query } from '@nestjs/common'; +import { + Body, + Controller, + Delete, + Get, + HttpCode, + HttpException, + HttpStatus, + Inject, + NotFoundException, + Param, + Post, + Query, + UnauthorizedException, +} from '@nestjs/common'; import { IsIn, IsNotEmpty, IsOptional, IsString } from 'class-validator'; import type { MemoryScope } from '../../nodes/memory/memory.types'; import { MemoryService } from '../../nodes/memory/memory.service'; +import { AgentsPersistenceService } from '../../agents/agents.persistence.service'; +import { CurrentPrincipal } from '../../auth/principal.decorator'; +import type { Principal } from '../../auth/auth.types'; class DocParamsDto { @IsString() @@ -70,7 +87,17 @@ class EnsureDirBodyDto extends ThreadAwareDto { @Controller('api/memory') export class MemoryController { - constructor(@Inject(MemoryService) private readonly memoryService: MemoryService) {} + constructor( + @Inject(MemoryService) private readonly memoryService: MemoryService, + @Inject(AgentsPersistenceService) private readonly persistence: AgentsPersistenceService, + ) {} + + private requirePrincipal(principal: Principal | null): Principal { + if (!principal) { + throw new UnauthorizedException({ error: 'unauthorized' }); + } + return principal; + } private resolveThreadId(scope: MemoryScope, ...candidates: Array): string | undefined { if (scope !== 'perThread') return undefined; @@ -83,34 +110,93 @@ export class MemoryController { throw new HttpException({ error: 'threadId required for perThread scope' }, HttpStatus.BAD_REQUEST); } + private async resolveAuthorizedThreadId( + scope: MemoryScope, + ownerUserId: string, + ...candidates: Array + ): Promise { + if (scope !== 'perThread') return undefined; + const threadId = this.resolveThreadId(scope, ...candidates); + if (!threadId) { + throw new HttpException({ error: 'threadId required for perThread scope' }, HttpStatus.BAD_REQUEST); + } + const thread = await this.persistence.getThreadById(threadId, { ownerUserId }); + if (!thread) { + throw new NotFoundException({ error: 'thread_not_found' }); + } + return threadId; + } + + private async filterDocsForPrincipal( + items: Array<{ nodeId: string; scope: MemoryScope; threadId?: string }>, + ownerUserId: string, + ): Promise> { + const cache = new Map(); + const filtered: Array<{ nodeId: string; scope: MemoryScope; threadId?: string }> = []; + for (const item of items) { + if (!item.threadId) { + filtered.push(item); + continue; + } + if (!cache.has(item.threadId)) { + const thread = await this.persistence.getThreadById(item.threadId, { ownerUserId }); + cache.set(item.threadId, !!thread); + } + if (cache.get(item.threadId)) { + filtered.push(item); + } + } + return filtered; + } + @Get('docs') - async listDocs(): Promise<{ items: Array<{ nodeId: string; scope: MemoryScope; threadId?: string }> }> { + async listDocs( + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ items: Array<{ nodeId: string; scope: MemoryScope; threadId?: string }> }> { + const currentPrincipal = this.requirePrincipal(principal); + const ownerUserId = currentPrincipal.userId; const items = await this.memoryService.listDocs(); - return { items }; + const filtered = await this.filterDocsForPrincipal(items, ownerUserId); + return { items: filtered }; } @Get(':nodeId/:scope/list') - async list(@Param() params: DocParamsDto, @Query() query: PathWithThreadQueryDto): Promise<{ items: Array<{ name: string; hasSubdocs: boolean }> }> { + async list( + @Param() params: DocParamsDto, + @Query() query: PathWithThreadQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ items: Array<{ name: string; hasSubdocs: boolean }> }> { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; const path = query.path ?? '/'; - const threadId = this.resolveThreadId(scope, params.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, query.threadId); const items = await this.memoryService.list(nodeId, scope, threadId, path || '/'); return { items }; } @Get(':nodeId/:scope/stat') - async stat(@Param() params: DocParamsDto, @Query() query: PathWithThreadQueryDto): Promise<{ exists: boolean; hasSubdocs: boolean; contentLength: number }> { + async stat( + @Param() params: DocParamsDto, + @Query() query: PathWithThreadQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ exists: boolean; hasSubdocs: boolean; contentLength: number }> { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; const path = query.path; - const threadId = this.resolveThreadId(scope, params.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, query.threadId); return this.memoryService.stat(nodeId, scope, threadId, path); } @Get(':nodeId/:scope/read') - async read(@Param() params: DocParamsDto, @Query() query: PathWithThreadQueryDto): Promise<{ content: string }> { + async read( + @Param() params: DocParamsDto, + @Query() query: PathWithThreadQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ content: string }> { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; const path = query.path; - const threadId = this.resolveThreadId(scope, params.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, query.threadId); try { const content = await this.memoryService.read(nodeId, scope, threadId, path); return { content }; @@ -123,16 +209,28 @@ export class MemoryController { @Post(':nodeId/:scope/append') @HttpCode(204) - async append(@Param() params: DocParamsDto, @Body() body: AppendBodyDto, @Query() query: ThreadOnlyQueryDto): Promise { + async append( + @Param() params: DocParamsDto, + @Body() body: AppendBodyDto, + @Query() query: ThreadOnlyQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; - const threadId = this.resolveThreadId(scope, params.threadId, body.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, body.threadId, query.threadId); await this.memoryService.append(nodeId, scope, threadId, body.path, body.data); } @Post(':nodeId/:scope/update') - async update(@Param() params: DocParamsDto, @Body() body: UpdateBodyDto, @Query() query: ThreadOnlyQueryDto): Promise<{ replaced: number }> { + async update( + @Param() params: DocParamsDto, + @Body() body: UpdateBodyDto, + @Query() query: ThreadOnlyQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ replaced: number }> { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; - const threadId = this.resolveThreadId(scope, params.threadId, body.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, body.threadId, query.threadId); try { const replaced = await this.memoryService.update(nodeId, scope, threadId, body.path, body.oldStr, body.newStr); return { replaced }; @@ -145,23 +243,39 @@ export class MemoryController { @Post(':nodeId/:scope/ensure-dir') @HttpCode(204) - async ensureDir(@Param() params: DocParamsDto, @Body() body: EnsureDirBodyDto, @Query() query: ThreadOnlyQueryDto): Promise { + async ensureDir( + @Param() params: DocParamsDto, + @Body() body: EnsureDirBodyDto, + @Query() query: ThreadOnlyQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; - const threadId = this.resolveThreadId(scope, params.threadId, body.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, body.threadId, query.threadId); await this.memoryService.ensureDir(nodeId, scope, threadId, body.path); } @Delete(':nodeId/:scope') - async remove(@Param() params: DocParamsDto, @Query() query: PathWithThreadQueryDto): Promise<{ removed: number }> { + async remove( + @Param() params: DocParamsDto, + @Query() query: PathWithThreadQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise<{ removed: number }> { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; - const threadId = this.resolveThreadId(scope, params.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, query.threadId); return this.memoryService.delete(nodeId, scope, threadId, query.path); } @Get(':nodeId/:scope/dump') - async dump(@Param() params: DocParamsDto, @Query() query: ThreadOnlyQueryDto): Promise { + async dump( + @Param() params: DocParamsDto, + @Query() query: ThreadOnlyQueryDto, + @CurrentPrincipal() principal: Principal | null, + ): Promise { + const currentPrincipal = this.requirePrincipal(principal); const { nodeId, scope } = params; - const threadId = this.resolveThreadId(scope, params.threadId, query.threadId); + const threadId = await this.resolveAuthorizedThreadId(scope, currentPrincipal.userId, params.threadId, query.threadId); return this.memoryService.dump(nodeId, scope, threadId); } } diff --git a/packages/platform-server/src/index.ts b/packages/platform-server/src/index.ts index 06398476c..b20223fd1 100644 --- a/packages/platform-server/src/index.ts +++ b/packages/platform-server/src/index.ts @@ -51,7 +51,7 @@ async function bootstrap() { origin: allowedOrigins.length ? allowedOrigins : true, methods: ['GET', 'HEAD', 'PUT', 'PATCH', 'POST', 'DELETE', 'OPTIONS'], allowedHeaders: ['Content-Type', 'Authorization', 'X-Requested-With', 'Accept'], - credentials: false, + credentials: true, }; // Enable CORS via Nest to avoid Fastify type-provider generic mismatches diff --git a/packages/platform-ui/src/api/http.ts b/packages/platform-ui/src/api/http.ts index 705b3d878..4aec1f9c2 100644 --- a/packages/platform-ui/src/api/http.ts +++ b/packages/platform-ui/src/api/http.ts @@ -8,7 +8,7 @@ function createHttp(baseURL: string): AxiosInstance { const inst = axios.create({ baseURL, headers: { 'Content-Type': 'application/json', Accept: 'application/json' }, - withCredentials: false, + withCredentials: true, }); // Response: unwrap data; error: normalize to AxiosError with server message if present diff --git a/packages/platform-ui/src/api/modules/auth.ts b/packages/platform-ui/src/api/modules/auth.ts new file mode 100644 index 000000000..c3b7b0424 --- /dev/null +++ b/packages/platform-ui/src/api/modules/auth.ts @@ -0,0 +1,21 @@ +import { asData, http } from '../http'; + +export type AuthMode = 'single_user' | 'oidc'; + +export type AuthStatusResponse = { + mode: AuthMode; + authenticated: boolean; + user: { + id: string; + email: string | null; + name: string | null; + } | null; +}; + +export async function getAuthStatus(): Promise { + return asData(http.get('/api/auth/status')); +} + +export async function logout(): Promise { + await http.post('/api/auth/logout'); +} diff --git a/packages/platform-ui/src/components/Sidebar.tsx b/packages/platform-ui/src/components/Sidebar.tsx index 6de3cab42..e0fba3d2b 100644 --- a/packages/platform-ui/src/components/Sidebar.tsx +++ b/packages/platform-ui/src/components/Sidebar.tsx @@ -20,19 +20,21 @@ export interface SubMenuItem { interface SidebarProps { menuItems: MenuItem[]; currentUser?: { - name: string; - email: string; - avatar?: string; + name: string | null; + email: string | null; + avatarUrl?: string | null; }; selectedMenuItem?: string; onMenuItemSelect?: (itemId: string) => void; + onLogout?: () => void; } export default function Sidebar({ menuItems, - currentUser = { name: 'John Doe', email: 'john@agyn.io' }, + currentUser = { name: 'John Doe', email: 'john@agyn.io', avatarUrl: null }, selectedMenuItem = 'graph', - onMenuItemSelect + onMenuItemSelect, + onLogout, }: SidebarProps) { const [expandedItems, setExpandedItems] = useState>(new Set(['agents'])); @@ -133,20 +135,29 @@ export default function Sidebar({ {/* User Footer */} -
+
- {currentUser.avatar ? ( - {currentUser.name} + {currentUser.avatarUrl ? ( + {currentUser.name ) : ( )}
-

{currentUser.name}

-

{currentUser.email}

+

{currentUser.name ?? currentUser.email ?? 'Agyn User'}

+

{currentUser.email ?? '—'}

+ {onLogout ? ( + + ) : null}
); diff --git a/packages/platform-ui/src/components/layouts/MainLayout.tsx b/packages/platform-ui/src/components/layouts/MainLayout.tsx index eefcfc081..04187065f 100644 --- a/packages/platform-ui/src/components/layouts/MainLayout.tsx +++ b/packages/platform-ui/src/components/layouts/MainLayout.tsx @@ -6,6 +6,12 @@ interface MainLayoutProps { menuItems: MenuItem[]; selectedMenuItem?: string; onMenuItemSelect?: (itemId: string) => void; + currentUser?: { + name: string | null; + email: string | null; + avatarUrl?: string | null; + }; + onLogout?: () => void; } export function MainLayout({ @@ -13,6 +19,8 @@ export function MainLayout({ menuItems, selectedMenuItem, onMenuItemSelect, + currentUser, + onLogout, }: MainLayoutProps) { return (
@@ -20,7 +28,9 @@ export function MainLayout({
{children} diff --git a/packages/platform-ui/src/layout/RootLayout.tsx b/packages/platform-ui/src/layout/RootLayout.tsx index 4d9073da7..370cba4ac 100644 --- a/packages/platform-ui/src/layout/RootLayout.tsx +++ b/packages/platform-ui/src/layout/RootLayout.tsx @@ -16,6 +16,7 @@ import { } from 'lucide-react'; import { MainLayout } from '../components/layouts/MainLayout'; import type { MenuItem } from '../components/Sidebar'; +import { useUser } from '@/user/user.runtime'; const MENU_ITEM_ROUTES: Record = { graph: '/agents/graph', @@ -73,6 +74,7 @@ function getMenuItemFromPath(pathname: string) { export function RootLayout() { const location = useLocation(); const navigate = useNavigate(); + const { user, mode, authenticated, logout } = useUser(); const selectedMenuItem = getMenuItemFromPath(location.pathname); @@ -88,11 +90,18 @@ export function RootLayout() { [location.pathname, navigate], ); + const canLogout = mode === 'oidc' && authenticated; + const handleLogout = useCallback(() => { + void logout(); + }, [logout]); + return ( diff --git a/packages/platform-ui/src/lib/graph/socket.ts b/packages/platform-ui/src/lib/graph/socket.ts index f00ffa60b..ed0a61418 100644 --- a/packages/platform-ui/src/lib/graph/socket.ts +++ b/packages/platform-ui/src/lib/graph/socket.ts @@ -121,7 +121,7 @@ class GraphSocket { reconnectionAttempts: Infinity, reconnectionDelay: 1000, reconnectionDelayMax: 5000, - withCredentials: false, + withCredentials: true, }; this.socketCleanup = []; this.managerCleanup = []; diff --git a/packages/platform-ui/src/user/UserProvider.tsx b/packages/platform-ui/src/user/UserProvider.tsx index 8c3b62232..b87274a5d 100644 --- a/packages/platform-ui/src/user/UserProvider.tsx +++ b/packages/platform-ui/src/user/UserProvider.tsx @@ -1,8 +1,189 @@ -import React from 'react'; -import type { UserContextType } from './user-types'; +import { useCallback, useEffect, useMemo, useState, type ReactNode } from 'react'; +import { config } from '@/config'; +import * as authApi from '@/api/modules/auth'; +import type { AuthStatusResponse } from '@/api/modules/auth'; +import type { User, UserContextType } from './user-types'; import { UserContext } from './user.runtime'; -export function UserProvider({ children }: { children: React.ReactNode }) { - const value: UserContextType = { user: { name: 'Casey Quinn', email: 'casey@example.com' } }; - return {children}; +type AuthState = { + user: User | null; + authenticated: boolean; + mode: authApi.AuthMode; + loading: boolean; + error: string | null; +}; + +const initialState: AuthState = { + user: null, + authenticated: false, + mode: 'single_user', + loading: true, + error: null, +}; + +function mapUser(payload: AuthStatusResponse['user']): User | null { + if (!payload) return null; + return { + id: payload.id, + email: payload.email ?? null, + name: payload.name ?? payload.email ?? null, + avatarUrl: null, + }; +} + +type SplashProps = { + title: string; + description?: string; + primaryLabel?: string; + onPrimary?: () => void; + secondaryLabel?: string; + onSecondary?: () => void; + error?: string | null; +}; + +function AuthSplash({ title, description, primaryLabel, onPrimary, secondaryLabel, onSecondary, error }: SplashProps) { + return ( +
+
+
{title}
+ {description ?

{description}

: null} + {error ?

{error}

: null} +
+ {primaryLabel ? ( + + ) : null} + {secondaryLabel ? ( + + ) : null} +
+
+
+ ); +} + +function createBypassValue(): UserContextType { + return { + user: { + id: 'mock-user', + name: 'Agyn User', + email: 'user@example.com', + avatarUrl: null, + }, + authenticated: true, + mode: 'single_user', + loading: false, + error: null, + login: () => {}, + logout: async () => {}, + refresh: async () => {}, + }; +} + +function AuthenticatedUserProvider({ children }: { children: ReactNode }) { + const [state, setState] = useState(initialState); + + const refresh = useCallback(async () => { + setState((prev) => ({ ...prev, loading: true, error: null })); + try { + const status = await authApi.getAuthStatus(); + setState({ + user: mapUser(status.user), + authenticated: status.authenticated, + mode: status.mode, + loading: false, + error: null, + }); + } catch (error) { + const message = error instanceof Error ? error.message : 'Unable to reach authentication service'; + setState((prev) => ({ ...prev, loading: false, error: message })); + } + }, []); + + useEffect(() => { + void refresh(); + }, [refresh]); + + const login = useCallback(() => { + const url = `${config.apiBaseUrl}/api/auth/login`; + window.location.assign(url); + }, []); + + const logout = useCallback(async () => { + try { + await authApi.logout(); + } finally { + await refresh(); + } + }, [refresh]); + + const contextValue = useMemo( + () => ({ + user: state.user, + authenticated: state.authenticated, + mode: state.mode, + loading: state.loading, + error: state.error, + login, + logout, + refresh, + }), + [state, login, logout, refresh], + ); + + if (state.loading) { + return ( + + ); + } + + if (state.error && !state.authenticated) { + return ( + + ); + } + + if (state.mode === 'oidc' && !state.authenticated) { + return ( + + ); + } + + return {children}; +} + +export function UserProvider({ children }: { children: ReactNode }) { + const env = typeof import.meta !== 'undefined' ? import.meta.env : undefined; + const bypassAuth = env?.MODE === 'test' || env?.STORYBOOK === 'true'; + + if (bypassAuth) { + const value = createBypassValue(); + return {children}; + } + + return {children}; } diff --git a/packages/platform-ui/src/user/user-types.ts b/packages/platform-ui/src/user/user-types.ts index 45a9e7acd..7b74ae320 100644 --- a/packages/platform-ui/src/user/user-types.ts +++ b/packages/platform-ui/src/user/user-types.ts @@ -1,8 +1,19 @@ +export type AuthMode = 'single_user' | 'oidc'; + export type User = { - name: string; - email: string; - avatarUrl?: string; + id: string; + name: string | null; + email: string | null; + avatarUrl?: string | null; }; -export type UserContextType = { user: User | null }; - +export type UserContextType = { + user: User | null; + authenticated: boolean; + mode: AuthMode; + loading: boolean; + error: string | null; + login: () => void; + logout: () => Promise; + refresh: () => Promise; +}; diff --git a/packages/platform-ui/src/user/user.runtime.ts b/packages/platform-ui/src/user/user.runtime.ts index 5c879b250..4acc8a74c 100644 --- a/packages/platform-ui/src/user/user.runtime.ts +++ b/packages/platform-ui/src/user/user.runtime.ts @@ -2,8 +2,20 @@ import React from 'react'; import type { UserContextType } from './user-types'; // Runtime-only context container; no components are exported here. -export const UserContext = React.createContext({ user: null }); +const noop = () => {}; +const noopAsync = async () => {}; + +export const UserContext = React.createContext({ + user: null, + authenticated: false, + mode: 'single_user', + loading: true, + error: null, + login: noop, + logout: noopAsync, + refresh: noopAsync, +}); + export function useUser() { return React.useContext(UserContext); } - diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index b955edc0d..3b84a83ec 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -147,6 +147,9 @@ importers: class-validator: specifier: ^0.14.1 version: 0.14.2 + cookie: + specifier: ^1.1.1 + version: 1.1.1 dockerode: specifier: ^4.0.8 version: 4.0.8 @@ -177,6 +180,9 @@ importers: openai: specifier: ^6.6.0 version: 6.6.0(ws@8.18.3)(zod@4.1.12) + openid-client: + specifier: ^6.8.1 + version: 6.8.1 p-limit: specifier: ^3.1.0 version: 3.1.0 @@ -4627,8 +4633,8 @@ packages: resolution: {integrity: sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==} engines: {node: '>= 0.6'} - cookie@1.0.2: - resolution: {integrity: sha512-9Kr/j4O16ISv8zBBhJoi4bXOYNTkFLOqSL3UDB0njXxCXNezjeyVrJyGOWtgfs/q2km1gwBcfH8q1yEGoMYunA==} + cookie@1.1.1: + resolution: {integrity: sha512-ei8Aos7ja0weRpFzJnEA9UHJ/7XQmqglbRwnf2ATjcB9Wq874VKH9kfjjirM6UhU2/E5fFYadylyhFldcqSidQ==} engines: {node: '>=18'} copy-anything@3.0.5: @@ -6100,6 +6106,9 @@ packages: joi@17.13.3: resolution: {integrity: sha512-otDA4ldcIx+ZXsKHWmp0YizCweVRZG96J10b0FevjfuncLO1oX59THoAmHkNubYJ+9gWsYsp5k8v4ib6oDv1fA==} + jose@6.1.3: + resolution: {integrity: sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==} + js-tiktoken@1.0.21: resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==} @@ -6700,6 +6709,9 @@ packages: engines: {node: ^14.16.0 || >=16.10.0} hasBin: true + oauth4webapi@3.8.3: + resolution: {integrity: sha512-pQ5BsX3QRTgnt5HxgHwgunIRaDXBdkT23tf8dfzmtTIL2LTpdmxgbpbBm0VgFWAIDlezQvQCTgnVIUmHupXHxw==} + object-assign@4.1.1: resolution: {integrity: sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} engines: {node: '>=0.10.0'} @@ -6757,6 +6769,9 @@ packages: zod: optional: true + openid-client@6.8.1: + resolution: {integrity: sha512-VoYT6enBo6Vj2j3Q5Ec0AezS+9YGzQo1f5Xc42lreMGlfP4ljiXPKVDvCADh+XHCV/bqPu/wWSiCVXbJKvrODw==} + optionator@0.9.4: resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} engines: {node: '>= 0.8.0'} @@ -7805,6 +7820,7 @@ packages: tar@7.4.3: resolution: {integrity: sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==} engines: {node: '>=18'} + deprecated: Old versions of tar are not supported, and contain widely publicized security vulnerabilities, which have been fixed in the current version. Please update. Support for old versions may be purchased (at exhorbitant rates) by contacting i@izs.me test-exclude@6.0.0: resolution: {integrity: sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w==} @@ -8241,6 +8257,7 @@ packages: whatwg-encoding@3.1.1: resolution: {integrity: sha512-6qN4hJdMwfYBtE3YBTTHhoeuUrDBPZmbQaxWAqSALV/MeEnR5z1xd8UKud2RAkFoPkmB+hli1TZSnyi84xz1vQ==} engines: {node: '>=18'} + deprecated: Use @exodus/bytes instead for a more spec-conformant and faster implementation whatwg-mimetype@4.0.0: resolution: {integrity: sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==} @@ -12220,15 +12237,6 @@ snapshots: msw: 2.11.3(@types/node@20.19.19)(typescript@5.9.2) vite: 7.1.6(@types/node@20.19.19)(jiti@2.5.1)(lightningcss@1.30.1)(tsx@4.20.5)(yaml@2.8.1) - '@vitest/mocker@3.2.4(msw@2.11.3(@types/node@24.5.2)(typescript@5.8.3))(vite@7.1.6(@types/node@20.19.19)(jiti@2.5.1)(lightningcss@1.30.1)(tsx@4.20.5)(yaml@2.8.1))': - dependencies: - '@vitest/spy': 3.2.4 - estree-walker: 3.0.3 - magic-string: 0.30.19 - optionalDependencies: - msw: 2.11.3(@types/node@24.5.2)(typescript@5.8.3) - vite: 7.1.6(@types/node@20.19.19)(jiti@2.5.1)(lightningcss@1.30.1)(tsx@4.20.5)(yaml@2.8.1) - '@vitest/mocker@3.2.4(msw@2.11.3(@types/node@24.5.2)(typescript@5.8.3))(vite@7.1.6(@types/node@24.5.2)(jiti@2.5.1)(lightningcss@1.30.1)(tsx@4.20.5)(yaml@2.8.1))': dependencies: '@vitest/spy': 3.2.4 @@ -12815,7 +12823,7 @@ snapshots: cookie@0.7.2: {} - cookie@1.0.2: {} + cookie@1.1.1: {} copy-anything@3.0.5: dependencies: @@ -14639,6 +14647,8 @@ snapshots: '@sideway/formula': 3.0.1 '@sideway/pinpoint': 2.0.0 + jose@6.1.3: {} + js-tiktoken@1.0.21: dependencies: base64-js: 1.5.1 @@ -14752,7 +14762,7 @@ snapshots: light-my-request@6.6.0: dependencies: - cookie: 1.0.2 + cookie: 1.1.1 process-warning: 4.0.1 set-cookie-parser: 2.7.1 @@ -15457,6 +15467,8 @@ snapshots: pkg-types: 2.3.0 tinyexec: 1.0.1 + oauth4webapi@3.8.3: {} + object-assign@4.1.1: {} object-inspect@1.13.4: {} @@ -15509,6 +15521,11 @@ snapshots: ws: 8.18.3 zod: 4.1.12 + openid-client@6.8.1: + dependencies: + jose: 6.1.3 + oauth4webapi: 3.8.3 + optionator@0.9.4: dependencies: deep-is: 0.1.4 @@ -17187,7 +17204,7 @@ snapshots: dependencies: '@types/chai': 5.2.2 '@vitest/expect': 3.2.4 - '@vitest/mocker': 3.2.4(msw@2.11.3(@types/node@24.5.2)(typescript@5.8.3))(vite@7.1.6(@types/node@20.19.19)(jiti@2.5.1)(lightningcss@1.30.1)(tsx@4.20.5)(yaml@2.8.1)) + '@vitest/mocker': 3.2.4(msw@2.11.3(@types/node@24.5.2)(typescript@5.8.3))(vite@7.1.6(@types/node@24.5.2)(jiti@2.5.1)(lightningcss@1.30.1)(tsx@4.20.5)(yaml@2.8.1)) '@vitest/pretty-format': 3.2.4 '@vitest/runner': 3.2.4 '@vitest/snapshot': 3.2.4