From 8ea66e7df9e0f081212ad4fd37d1e6cce49c6a06 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 11 Nov 2024 22:05:13 +0000 Subject: [PATCH 01/11] Tests for protocol version negotiation Resolves #44. --- src/client/index.test.ts | 125 ++++++++++++++++++++++++++++- src/server/index.test.ts | 166 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 285 insertions(+), 6 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index d93ca39..a6dcc73 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -3,11 +3,130 @@ /* eslint-disable @typescript-eslint/no-unused-expressions */ import { Client } from "./index.js"; import { z } from "zod"; -import { RequestSchema, NotificationSchema, ResultSchema } from "../types.js"; +import { + RequestSchema, + NotificationSchema, + ResultSchema, + LATEST_PROTOCOL_VERSION, + SUPPORTED_PROTOCOL_VERSIONS, +} from "../types.js"; +import { Transport } from "../shared/transport.js"; + +test("should initialize with matching protocol version", async () => { + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.method === "initialize") { + clientTransport.onmessage?.({ + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { + name: "test", + version: "1.0", + }, + }, + }); + } + return Promise.resolve(); + }), + }; + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + await client.connect(clientTransport); + + // Should have sent initialize with latest version + expect(clientTransport.send).toHaveBeenCalledWith( + expect.objectContaining({ + method: "initialize", + params: expect.objectContaining({ + protocolVersion: LATEST_PROTOCOL_VERSION, + }), + }), + ); +}); + +test("should initialize with supported older protocol version", async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.method === "initialize") { + clientTransport.onmessage?.({ + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: OLD_VERSION, + capabilities: {}, + serverInfo: { + name: "test", + version: "1.0", + }, + }, + }); + } + return Promise.resolve(); + }), + }; + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + await client.connect(clientTransport); + + // Connection should succeed with the older version + expect(client.getServerVersion()).toEqual({ + name: "test", + version: "1.0", + }); +}); + +test("should reject unsupported protocol version", async () => { + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.method === "initialize") { + clientTransport.onmessage?.({ + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: "invalid-version", + capabilities: {}, + serverInfo: { + name: "test", + version: "1.0", + }, + }, + }); + } + return Promise.resolve(); + }), + }; + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + await expect(client.connect(clientTransport)).rejects.toThrow( + "Server's protocol version is not supported: invalid-version", + ); +}); /* -Test that custom request/notification/result schemas can be used with the Client class. -*/ + Test that custom request/notification/result schemas can be used with the Client class. + */ test("should typecheck", () => { const GetWeatherRequestSchema = RequestSchema.extend({ method: z.literal("weather/get"), diff --git a/src/server/index.test.ts b/src/server/index.test.ts index be33c58..5045e2a 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -3,11 +3,171 @@ /* eslint-disable @typescript-eslint/no-unused-expressions */ import { Server } from "./index.js"; import { z } from "zod"; -import { RequestSchema, NotificationSchema, ResultSchema } from "../types.js"; +import { + RequestSchema, + NotificationSchema, + ResultSchema, + LATEST_PROTOCOL_VERSION, + SUPPORTED_PROTOCOL_VERSIONS, + InitializeRequestSchema, + InitializeResultSchema, +} from "../types.js"; +import { Transport } from "../shared/transport.js"; + +test("should accept latest protocol version", async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise((resolve) => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: "test server", + version: "1.0", + }, + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }), + }; + + const server = new Server({ + name: "test server", + version: "1.0", + }); + + await server.connect(serverTransport); + + // Simulate initialize request with latest version + serverTransport.onmessage?.({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + clientInfo: { + name: "test client", + version: "1.0", + }, + }, + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test("should accept supported older protocol version", async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise((resolve) => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: OLD_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: "test server", + version: "1.0", + }, + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }), + }; + + const server = new Server({ + name: "test server", + version: "1.0", + }); + + await server.connect(serverTransport); + + // Simulate initialize request with older version + serverTransport.onmessage?.({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: OLD_VERSION, + capabilities: {}, + clientInfo: { + name: "test client", + version: "1.0", + }, + }, + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test("should handle unsupported protocol version", async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise((resolve) => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: "test server", + version: "1.0", + }, + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }), + }; + + const server = new Server({ + name: "test server", + version: "1.0", + }); + + await server.connect(serverTransport); + + // Simulate initialize request with unsupported version + serverTransport.onmessage?.({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "invalid-version", + capabilities: {}, + clientInfo: { + name: "test client", + version: "1.0", + }, + }, + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); /* -Test that custom request/notification/result schemas can be used with the Server class. -*/ + Test that custom request/notification/result schemas can be used with the Server class. + */ test("should typecheck", () => { const GetWeatherRequestSchema = RequestSchema.extend({ method: z.literal("weather/get"), From c82d633d8d3589f52090caf4ca06f148571afe91 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 11 Nov 2024 22:08:46 +0000 Subject: [PATCH 02/11] Disconnect from client side if initialization fails Resolves #24. --- src/client/index.test.ts | 2 ++ src/client/index.ts | 62 +++++++++++++++++++++++----------------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index a6dcc73..86b12ed 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -122,6 +122,8 @@ test("should reject unsupported protocol version", async () => { await expect(client.connect(clientTransport)).rejects.toThrow( "Server's protocol version is not supported: invalid-version", ); + + expect(clientTransport.close).toHaveBeenCalled(); }); /* diff --git a/src/client/index.ts b/src/client/index.ts index c3662b2..3b194f2 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -32,7 +32,7 @@ import { ServerCapabilities, SubscribeRequest, SUPPORTED_PROTOCOL_VERSIONS, - UnsubscribeRequest + UnsubscribeRequest, } from "../types.js"; /** @@ -82,34 +82,40 @@ export class Client< override async connect(transport: Transport): Promise { await super.connect(transport); - const result = await this.request( - { - method: "initialize", - params: { - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: {}, - clientInfo: this._clientInfo, + try { + const result = await this.request( + { + method: "initialize", + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + clientInfo: this._clientInfo, + }, }, - }, - InitializeResultSchema, - ); - - if (result === undefined) { - throw new Error(`Server sent invalid initialize result: ${result}`); - } - - if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) { - throw new Error( - `Server's protocol version is not supported: ${result.protocolVersion}`, + InitializeResultSchema, ); - } - - this._serverCapabilities = result.capabilities; - this._serverVersion = result.serverInfo; - await this.notification({ - method: "notifications/initialized", - }); + if (result === undefined) { + throw new Error(`Server sent invalid initialize result: ${result}`); + } + + if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) { + throw new Error( + `Server's protocol version is not supported: ${result.protocolVersion}`, + ); + } + + this._serverCapabilities = result.capabilities; + this._serverVersion = result.serverInfo; + + await this.notification({ + method: "notifications/initialized", + }); + } catch (error) { + // Disconnect if initialization fails. + void this.close(); + throw error; + } } /** @@ -219,7 +225,9 @@ export class Client< async callTool( params: CallToolRequest["params"], - resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + resultSchema: + | typeof CallToolResultSchema + | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, onprogress?: ProgressCallback, ) { return this.request( From 852ebb9ec0bbfb7c58664a0373f111e167ed25e0 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 11 Nov 2024 22:27:45 +0000 Subject: [PATCH 03/11] Add an InMemoryTransport --- src/inMemory.test.ts | 94 ++++++++++++++++++++++++++++++++++++++++++++ src/inMemory.ts | 54 +++++++++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 src/inMemory.test.ts create mode 100644 src/inMemory.ts diff --git a/src/inMemory.test.ts b/src/inMemory.test.ts new file mode 100644 index 0000000..f7e9e97 --- /dev/null +++ b/src/inMemory.test.ts @@ -0,0 +1,94 @@ +import { InMemoryTransport } from "./inMemory.js"; +import { JSONRPCMessage } from "./types.js"; + +describe("InMemoryTransport", () => { + let clientTransport: InMemoryTransport; + let serverTransport: InMemoryTransport; + + beforeEach(() => { + [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + }); + + test("should create linked pair", () => { + expect(clientTransport).toBeDefined(); + expect(serverTransport).toBeDefined(); + }); + + test("should start without error", async () => { + await expect(clientTransport.start()).resolves.not.toThrow(); + await expect(serverTransport.start()).resolves.not.toThrow(); + }); + + test("should send message from client to server", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + id: 1, + }; + + let receivedMessage: JSONRPCMessage | undefined; + serverTransport.onmessage = (msg) => { + receivedMessage = msg; + }; + + await clientTransport.send(message); + expect(receivedMessage).toEqual(message); + }); + + test("should send message from server to client", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + id: 1, + }; + + let receivedMessage: JSONRPCMessage | undefined; + clientTransport.onmessage = (msg) => { + receivedMessage = msg; + }; + + await serverTransport.send(message); + expect(receivedMessage).toEqual(message); + }); + + test("should handle close", async () => { + let clientClosed = false; + let serverClosed = false; + + clientTransport.onclose = () => { + clientClosed = true; + }; + + serverTransport.onclose = () => { + serverClosed = true; + }; + + await clientTransport.close(); + expect(clientClosed).toBe(true); + expect(serverClosed).toBe(true); + }); + + test("should throw error when sending after close", async () => { + await clientTransport.close(); + await expect( + clientTransport.send({ jsonrpc: "2.0", method: "test", id: 1 }), + ).rejects.toThrow("Not connected"); + }); + + test("should queue messages sent before start", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + id: 1, + }; + + let receivedMessage: JSONRPCMessage | undefined; + serverTransport.onmessage = (msg) => { + receivedMessage = msg; + }; + + await clientTransport.send(message); + await serverTransport.start(); + expect(receivedMessage).toEqual(message); + }); +}); diff --git a/src/inMemory.ts b/src/inMemory.ts new file mode 100644 index 0000000..2763f38 --- /dev/null +++ b/src/inMemory.ts @@ -0,0 +1,54 @@ +import { Transport } from "./shared/transport.js"; +import { JSONRPCMessage } from "./types.js"; + +/** + * In-memory transport for creating clients and servers that talk to each other within the same process. + */ +export class InMemoryTransport implements Transport { + private _otherTransport?: InMemoryTransport; + private _messageQueue: JSONRPCMessage[] = []; + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + /** + * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. + */ + static createLinkedPair(): [InMemoryTransport, InMemoryTransport] { + const clientTransport = new InMemoryTransport(); + const serverTransport = new InMemoryTransport(); + clientTransport._otherTransport = serverTransport; + serverTransport._otherTransport = clientTransport; + return [clientTransport, serverTransport]; + } + + async start(): Promise { + // Process any messages that were queued before start was called + while (this._messageQueue.length > 0) { + const message = this._messageQueue.shift(); + if (message) { + this.onmessage?.(message); + } + } + } + + async close(): Promise { + const other = this._otherTransport; + this._otherTransport = undefined; + await other?.close(); + this.onclose?.(); + } + + async send(message: JSONRPCMessage): Promise { + if (!this._otherTransport) { + throw new Error("Not connected"); + } + + if (this._otherTransport.onmessage) { + this._otherTransport.onmessage(message); + } else { + this._otherTransport._messageQueue.push(message); + } + } +} From 90e91cdd6223381f4cdfe0ba5521645b99bcc136 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 11 Nov 2024 22:30:50 +0000 Subject: [PATCH 04/11] Tests and assertions for client/server capabilities Resolves #45. --- src/client/index.test.ts | 60 ++++++++++++++++++++++++++++++++++++++++ src/client/index.ts | 22 +++++++++++++++ src/server/index.test.ts | 31 +++++++++++++++++++-- src/server/index.ts | 19 +++++++++++-- 4 files changed, 128 insertions(+), 4 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 86b12ed..6bcbf56 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -9,8 +9,13 @@ import { ResultSchema, LATEST_PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS, + InitializeRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; +import { Server } from "../server/index.js"; +import { InMemoryTransport } from "../inMemory.js"; test("should initialize with matching protocol version", async () => { const clientTransport: Transport = { @@ -126,6 +131,61 @@ test("should reject unsupported protocol version", async () => { expect(clientTransport.close).toHaveBeenCalled(); }); +test("should respect server capabilities", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + server.setRequestHandler(InitializeRequestSchema, (request) => ({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: { + resources: {}, + tools: {}, + }, + serverInfo: { + name: "test", + version: "1.0", + }, + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [], + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // Server supports resources and tools, but not prompts + expect(client.getServerCapabilities()).toEqual({ + resources: {}, + tools: {}, + }); + + // These should work + await expect(client.listResources()).resolves.not.toThrow(); + await expect(client.listTools()).resolves.not.toThrow(); + + // This should throw because prompts are not supported + await expect(client.listPrompts()).rejects.toThrow( + "Server does not support prompts", + ); +}); + /* Test that custom request/notification/result schemas can be used with the Client class. */ diff --git a/src/client/index.ts b/src/client/index.ts index 3b194f2..b2ead25 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -132,6 +132,17 @@ export class Client< return this._serverVersion; } + private assertCapability( + capability: keyof ServerCapabilities, + method: string, + ) { + if (!this._serverCapabilities?.[capability]) { + throw new Error( + `Server does not support ${capability} (required for ${method})`, + ); + } + } + async ping() { return this.request({ method: "ping" }, EmptyResultSchema); } @@ -140,6 +151,7 @@ export class Client< params: CompleteRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("prompts", "completion/complete"); return this.request( { method: "completion/complete", params }, CompleteResultSchema, @@ -148,6 +160,7 @@ export class Client< } async setLoggingLevel(level: LoggingLevel) { + this.assertCapability("logging", "logging/setLevel"); return this.request( { method: "logging/setLevel", params: { level } }, EmptyResultSchema, @@ -158,6 +171,7 @@ export class Client< params: GetPromptRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("prompts", "prompts/get"); return this.request( { method: "prompts/get", params }, GetPromptResultSchema, @@ -169,6 +183,7 @@ export class Client< params?: ListPromptsRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("prompts", "prompts/list"); return this.request( { method: "prompts/list", params }, ListPromptsResultSchema, @@ -180,6 +195,7 @@ export class Client< params?: ListResourcesRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("resources", "resources/list"); return this.request( { method: "resources/list", params }, ListResourcesResultSchema, @@ -191,6 +207,7 @@ export class Client< params?: ListResourceTemplatesRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("resources", "resources/templates/list"); return this.request( { method: "resources/templates/list", params }, ListResourceTemplatesResultSchema, @@ -202,6 +219,7 @@ export class Client< params: ReadResourceRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("resources", "resources/read"); return this.request( { method: "resources/read", params }, ReadResourceResultSchema, @@ -210,6 +228,7 @@ export class Client< } async subscribeResource(params: SubscribeRequest["params"]) { + this.assertCapability("resources", "resources/subscribe"); return this.request( { method: "resources/subscribe", params }, EmptyResultSchema, @@ -217,6 +236,7 @@ export class Client< } async unsubscribeResource(params: UnsubscribeRequest["params"]) { + this.assertCapability("resources", "resources/unsubscribe"); return this.request( { method: "resources/unsubscribe", params }, EmptyResultSchema, @@ -230,6 +250,7 @@ export class Client< | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, onprogress?: ProgressCallback, ) { + this.assertCapability("tools", "tools/call"); return this.request( { method: "tools/call", params }, resultSchema, @@ -241,6 +262,7 @@ export class Client< params?: ListToolsRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("tools", "tools/list"); return this.request( { method: "tools/list", params }, ListToolsResultSchema, diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 5045e2a..dd5c94a 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -9,10 +9,10 @@ import { ResultSchema, LATEST_PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS, - InitializeRequestSchema, - InitializeResultSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; +import { InMemoryTransport } from "../inMemory.js"; +import { Client } from "../client/index.js"; test("should accept latest protocol version", async () => { let sendPromiseResolve: (value: unknown) => void; @@ -165,6 +165,33 @@ test("should handle unsupported protocol version", async () => { await expect(sendPromise).resolves.toBeUndefined(); }); +test("should respect client capabilities", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + expect(server.getClientCapabilities()).toEqual({}); + + // This should throw because roots are not supported by the client + await expect(server.listRoots()).rejects.toThrow( + "Client does not support roots", + ); +}); + /* Test that custom request/notification/result schemas can be used with the Server class. */ diff --git a/src/server/index.ts b/src/server/index.ts index 8cf2a91..203d3bd 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -25,7 +25,7 @@ import { ServerRequest, ServerResult, SetLevelRequestSchema, - SUPPORTED_PROTOCOL_VERSIONS + SUPPORTED_PROTOCOL_VERSIONS, } from "../types.js"; /** @@ -93,7 +93,9 @@ export class Server< this._clientVersion = request.params.clientInfo; return { - protocolVersion: SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) ? requestedVersion : LATEST_PROTOCOL_VERSION, + protocolVersion: SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) + ? requestedVersion + : LATEST_PROTOCOL_VERSION, capabilities: this.getCapabilities(), serverInfo: this._serverInfo, }; @@ -138,6 +140,17 @@ export class Server< }; } + private assertCapability( + capability: keyof ClientCapabilities, + method: string, + ) { + if (!this._clientCapabilities?.[capability]) { + throw new Error( + `Client does not support ${capability} (required for ${method})`, + ); + } + } + async ping() { return this.request({ method: "ping" }, EmptyResultSchema); } @@ -146,6 +159,7 @@ export class Server< params: CreateMessageRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("sampling", "sampling/createMessage"); return this.request( { method: "sampling/createMessage", params }, CreateMessageResultSchema, @@ -157,6 +171,7 @@ export class Server< params?: ListRootsRequest["params"], onprogress?: ProgressCallback, ) { + this.assertCapability("roots", "roots/list"); return this.request( { method: "roots/list", params }, ListRootsResultSchema, From 7ab58805692841966d782677eeb943d9cb443025 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 11 Nov 2024 22:38:43 +0000 Subject: [PATCH 05/11] Use explicit capabilities lists for now --- src/cli.ts | 29 +++++++--- src/client/index.test.ts | 81 ++++++++++++++++++--------- src/client/index.ts | 8 ++- src/server/index.test.ts | 116 ++++++++++++++++++++++++++++++--------- src/server/index.ts | 28 ++-------- 5 files changed, 179 insertions(+), 83 deletions(-) diff --git a/src/cli.ts b/src/cli.ts index 2713d77..a7b6230 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -17,10 +17,15 @@ import { StdioServerTransport } from "./server/stdio.js"; import { ListResourcesResultSchema } from "./types.js"; async function runClient(url_or_command: string, args: string[]) { - const client = new Client({ - name: "mcp-typescript test client", - version: "0.1.0", - }); + const client = new Client( + { + name: "mcp-typescript test client", + version: "0.1.0", + }, + { + sampling: {}, + }, + ); let clientTransport; @@ -97,10 +102,18 @@ async function runServer(port: number | null) { console.log(`Server running on http://localhost:${port}/sse`); }); } else { - const server = new Server({ - name: "mcp-typescript test server", - version: "0.1.0", - }); + const server = new Server( + { + name: "mcp-typescript test server", + version: "0.1.0", + }, + { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + ); const transport = new StdioServerTransport(); await server.connect(transport); diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 6bcbf56..62e30b6 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -40,10 +40,15 @@ test("should initialize with matching protocol version", async () => { }), }; - const client = new Client({ - name: "test client", - version: "1.0", - }); + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + sampling: {}, + }, + ); await client.connect(clientTransport); @@ -82,10 +87,15 @@ test("should initialize with supported older protocol version", async () => { }), }; - const client = new Client({ - name: "test client", - version: "1.0", - }); + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + sampling: {}, + }, + ); await client.connect(clientTransport); @@ -119,10 +129,15 @@ test("should reject unsupported protocol version", async () => { }), }; - const client = new Client({ - name: "test client", - version: "1.0", - }); + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + sampling: {}, + }, + ); await expect(client.connect(clientTransport)).rejects.toThrow( "Server's protocol version is not supported: invalid-version", @@ -132,12 +147,18 @@ test("should reject unsupported protocol version", async () => { }); test("should respect server capabilities", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + resources: {}, + tools: {}, + }, + ); - server.setRequestHandler(InitializeRequestSchema, (request) => ({ + server.setRequestHandler(InitializeRequestSchema, (_request) => ({ protocolVersion: LATEST_PROTOCOL_VERSION, capabilities: { resources: {}, @@ -160,10 +181,15 @@ test("should respect server capabilities", async () => { const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + sampling: {}, + }, + ); await Promise.all([ client.connect(clientTransport), @@ -231,10 +257,15 @@ test("should typecheck", () => { WeatherRequest, WeatherNotification, WeatherResult - >({ - name: "WeatherClient", - version: "1.0.0", - }); + >( + { + name: "WeatherClient", + version: "1.0.0", + }, + { + sampling: {}, + }, + ); // Typecheck that only valid weather requests/notifications/results are allowed false && diff --git a/src/client/index.ts b/src/client/index.ts index b2ead25..83d7bf7 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -3,6 +3,7 @@ import { Transport } from "../shared/transport.js"; import { CallToolRequest, CallToolResultSchema, + ClientCapabilities, ClientNotification, ClientRequest, ClientResult, @@ -75,7 +76,10 @@ export class Client< /** * Initializes this client with the given name and version information. */ - constructor(private _clientInfo: Implementation) { + constructor( + private _clientInfo: Implementation, + private _capabilities: ClientCapabilities, + ) { super(); } @@ -88,7 +92,7 @@ export class Client< method: "initialize", params: { protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: {}, + capabilities: this._capabilities, clientInfo: this._clientInfo, }, }, diff --git a/src/server/index.test.ts b/src/server/index.test.ts index dd5c94a..f5d9311 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -9,6 +9,7 @@ import { ResultSchema, LATEST_PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS, + CreateMessageRequestSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; @@ -39,10 +40,18 @@ test("should accept latest protocol version", async () => { }), }; - const server = new Server({ - name: "test server", - version: "1.0", - }); + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + ); await server.connect(serverTransport); @@ -90,10 +99,18 @@ test("should accept supported older protocol version", async () => { }), }; - const server = new Server({ - name: "test server", - version: "1.0", - }); + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + ); await server.connect(serverTransport); @@ -140,10 +157,18 @@ test("should handle unsupported protocol version", async () => { }), }; - const server = new Server({ - name: "test server", - version: "1.0", - }); + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + ); await server.connect(serverTransport); @@ -166,14 +191,39 @@ test("should handle unsupported protocol version", async () => { }); test("should respect client capabilities", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + ); + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + sampling: {}, + }, + ); - const client = new Client({ - name: "test client", - version: "1.0", + // Implement request handler for sampling/createMessage + client.setRequestHandler(CreateMessageRequestSchema, async (request) => { + // Mock implementation of createMessage + return { + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "This is a test response", + }, + }; }); const [clientTransport, serverTransport] = @@ -184,9 +234,17 @@ test("should respect client capabilities", async () => { server.connect(serverTransport), ]); - expect(server.getClientCapabilities()).toEqual({}); + expect(server.getClientCapabilities()).toEqual({ sampling: {} }); + + // This should work because sampling is supported by the client + await expect( + server.createMessage({ + messages: [], + maxTokens: 10, + }), + ).resolves.not.toThrow(); - // This should throw because roots are not supported by the client + // This should still throw because roots are not supported by the client await expect(server.listRoots()).rejects.toThrow( "Client does not support roots", ); @@ -237,10 +295,18 @@ test("should typecheck", () => { WeatherRequest, WeatherNotification, WeatherResult - >({ - name: "WeatherServer", - version: "1.0.0", - }); + >( + { + name: "WeatherServer", + version: "1.0.0", + }, + { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + ); // Typecheck that only valid weather requests/notifications/results are allowed weatherServer.setRequestHandler(GetWeatherRequestSchema, (request) => { diff --git a/src/server/index.ts b/src/server/index.ts index 203d3bd..f60562b 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -73,7 +73,10 @@ export class Server< /** * Initializes this server with the given name and version information. */ - constructor(private _serverInfo: Implementation) { + constructor( + private _serverInfo: Implementation, + private _capabilities: ServerCapabilities, + ) { super(); this.setRequestHandler(InitializeRequestSchema, (request) => @@ -116,28 +119,7 @@ export class Server< } private getCapabilities(): ServerCapabilities { - return { - prompts: this._requestHandlers.has( - ListPromptsRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - resources: this._requestHandlers.has( - ListResourcesRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - tools: this._requestHandlers.has( - ListToolsRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - logging: this._requestHandlers.has( - SetLevelRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - }; + return this._capabilities; } private assertCapability( From af03a74f5a54f2d2414fe192d25200202d31938a Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 11 Nov 2024 22:42:01 +0000 Subject: [PATCH 06/11] Fix missing argument in cli.ts --- src/cli.ts | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/cli.ts b/src/cli.ts index a7b6230..5d64087 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -68,10 +68,13 @@ async function runServer(port: number | null) { console.log("Got new SSE connection"); const transport = new SSEServerTransport("/message", res); - const server = new Server({ - name: "mcp-typescript test server", - version: "0.1.0", - }); + const server = new Server( + { + name: "mcp-typescript test server", + version: "0.1.0", + }, + {}, + ); servers.push(server); From 8e369b717169ea896e3bcd02a0e42eb486160cfa Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 11 Nov 2024 22:49:53 +0000 Subject: [PATCH 07/11] Lint --- src/server/index.ts | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/server/index.ts b/src/server/index.ts index f60562b..c8be405 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -10,11 +10,8 @@ import { InitializeRequestSchema, InitializeResult, LATEST_PROTOCOL_VERSION, - ListPromptsRequestSchema, - ListResourcesRequestSchema, ListRootsRequest, ListRootsResultSchema, - ListToolsRequestSchema, LoggingMessageNotification, Notification, Request, @@ -24,7 +21,6 @@ import { ServerNotification, ServerRequest, ServerResult, - SetLevelRequestSchema, SUPPORTED_PROTOCOL_VERSIONS, } from "../types.js"; From dc45f5d148fdb1cb42fe6319ee0a03c74996d93a Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 12 Nov 2024 10:57:59 +0000 Subject: [PATCH 08/11] Handle capabilities checks at the request() level, make non-strict by default --- src/cli.ts | 18 ++++--- src/client/index.test.ts | 27 +++++++--- src/client/index.ts | 112 +++++++++++++++++++++++++++++++-------- src/server/index.test.ts | 60 ++++++++++++--------- src/server/index.ts | 45 ++++++++++++++-- src/shared/protocol.ts | 29 +++++++++- src/types.ts | 46 ++++++++++------ 7 files changed, 254 insertions(+), 83 deletions(-) diff --git a/src/cli.ts b/src/cli.ts index 5d64087..d544497 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -23,7 +23,9 @@ async function runClient(url_or_command: string, args: string[]) { version: "0.1.0", }, { - sampling: {}, + capabilities: { + sampling: {}, + }, }, ); @@ -73,7 +75,9 @@ async function runServer(port: number | null) { name: "mcp-typescript test server", version: "0.1.0", }, - {}, + { + capabilities: {}, + }, ); servers.push(server); @@ -111,10 +115,12 @@ async function runServer(port: number | null) { version: "0.1.0", }, { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, }, ); diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 62e30b6..7cbafd1 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -46,7 +46,9 @@ test("should initialize with matching protocol version", async () => { version: "1.0", }, { - sampling: {}, + capabilities: { + sampling: {}, + }, }, ); @@ -93,7 +95,9 @@ test("should initialize with supported older protocol version", async () => { version: "1.0", }, { - sampling: {}, + capabilities: { + sampling: {}, + }, }, ); @@ -135,7 +139,9 @@ test("should reject unsupported protocol version", async () => { version: "1.0", }, { - sampling: {}, + capabilities: { + sampling: {}, + }, }, ); @@ -153,8 +159,10 @@ test("should respect server capabilities", async () => { version: "1.0", }, { - resources: {}, - tools: {}, + capabilities: { + resources: {}, + tools: {}, + }, }, ); @@ -187,7 +195,10 @@ test("should respect server capabilities", async () => { version: "1.0", }, { - sampling: {}, + capabilities: { + sampling: {}, + }, + enforceStrictCapabilities: true, }, ); @@ -263,7 +274,9 @@ test("should typecheck", () => { version: "1.0.0", }, { - sampling: {}, + capabilities: { + sampling: {}, + }, }, ); diff --git a/src/client/index.ts b/src/client/index.ts index 83d7bf7..402c313 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,4 +1,8 @@ -import { ProgressCallback, Protocol } from "../shared/protocol.js"; +import { + ProgressCallback, + Protocol, + ProtocolOptions, +} from "../shared/protocol.js"; import { Transport } from "../shared/transport.js"; import { CallToolRequest, @@ -36,6 +40,13 @@ import { UnsubscribeRequest, } from "../types.js"; +export type ClientOptions = ProtocolOptions & { + /** + * Capabilities to advertise as being supported by this client. + */ + capabilities: ClientCapabilities; +}; + /** * An MCP client on top of a pluggable transport. * @@ -72,15 +83,28 @@ export class Client< > { private _serverCapabilities?: ServerCapabilities; private _serverVersion?: Implementation; + private _capabilities: ClientCapabilities; /** * Initializes this client with the given name and version information. */ constructor( private _clientInfo: Implementation, - private _capabilities: ClientCapabilities, + options: ClientOptions, ) { - super(); + super(options); + this._capabilities = options.capabilities; + } + + protected assertCapability( + capability: keyof ServerCapabilities, + method: string, + ): void { + if (!this._serverCapabilities?.[capability]) { + throw new Error( + `Server does not support ${capability} (required for ${method})`, + ); + } } override async connect(transport: Transport): Promise { @@ -136,14 +160,69 @@ export class Client< return this._serverVersion; } - private assertCapability( - capability: keyof ServerCapabilities, - method: string, - ) { - if (!this._serverCapabilities?.[capability]) { - throw new Error( - `Server does not support ${capability} (required for ${method})`, - ); + protected assertCapabilityForMethod(method: RequestT["method"]): void { + switch (method as ClientRequest["method"]) { + case "logging/setLevel": + if (!this._serverCapabilities?.logging) { + throw new Error( + "Server does not support logging (required for logging/setLevel)", + ); + } + break; + + case "prompts/get": + case "prompts/list": + if (!this._serverCapabilities?.prompts) { + throw new Error( + `Server does not support prompts (required for ${method})`, + ); + } + break; + + case "resources/list": + case "resources/templates/list": + case "resources/read": + case "resources/subscribe": + case "resources/unsubscribe": + if (!this._serverCapabilities?.resources) { + throw new Error( + `Server does not support resources (required for ${method})`, + ); + } + + if ( + method === "resources/subscribe" && + !this._serverCapabilities.resources.subscribe + ) { + throw new Error("Server does not support resource subscriptions"); + } + + break; + + case "tools/call": + case "tools/list": + if (!this._serverCapabilities?.tools) { + throw new Error( + `Server does not support tools (required for ${method})`, + ); + } + break; + + case "completion/complete": + if (!this._serverCapabilities?.prompts) { + throw new Error( + "Server does not support prompts (required for completion/complete)", + ); + } + break; + + case "initialize": + // No specific capability required for initialize + break; + + case "ping": + // No specific capability required for ping + break; } } @@ -155,7 +234,6 @@ export class Client< params: CompleteRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("prompts", "completion/complete"); return this.request( { method: "completion/complete", params }, CompleteResultSchema, @@ -164,7 +242,6 @@ export class Client< } async setLoggingLevel(level: LoggingLevel) { - this.assertCapability("logging", "logging/setLevel"); return this.request( { method: "logging/setLevel", params: { level } }, EmptyResultSchema, @@ -175,7 +252,6 @@ export class Client< params: GetPromptRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("prompts", "prompts/get"); return this.request( { method: "prompts/get", params }, GetPromptResultSchema, @@ -187,7 +263,6 @@ export class Client< params?: ListPromptsRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("prompts", "prompts/list"); return this.request( { method: "prompts/list", params }, ListPromptsResultSchema, @@ -199,7 +274,6 @@ export class Client< params?: ListResourcesRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("resources", "resources/list"); return this.request( { method: "resources/list", params }, ListResourcesResultSchema, @@ -211,7 +285,6 @@ export class Client< params?: ListResourceTemplatesRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("resources", "resources/templates/list"); return this.request( { method: "resources/templates/list", params }, ListResourceTemplatesResultSchema, @@ -223,7 +296,6 @@ export class Client< params: ReadResourceRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("resources", "resources/read"); return this.request( { method: "resources/read", params }, ReadResourceResultSchema, @@ -232,7 +304,6 @@ export class Client< } async subscribeResource(params: SubscribeRequest["params"]) { - this.assertCapability("resources", "resources/subscribe"); return this.request( { method: "resources/subscribe", params }, EmptyResultSchema, @@ -240,7 +311,6 @@ export class Client< } async unsubscribeResource(params: UnsubscribeRequest["params"]) { - this.assertCapability("resources", "resources/unsubscribe"); return this.request( { method: "resources/unsubscribe", params }, EmptyResultSchema, @@ -254,7 +324,6 @@ export class Client< | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, onprogress?: ProgressCallback, ) { - this.assertCapability("tools", "tools/call"); return this.request( { method: "tools/call", params }, resultSchema, @@ -266,7 +335,6 @@ export class Client< params?: ListToolsRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("tools", "tools/list"); return this.request( { method: "tools/list", params }, ListToolsResultSchema, diff --git a/src/server/index.test.ts b/src/server/index.test.ts index f5d9311..274bb07 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -46,10 +46,12 @@ test("should accept latest protocol version", async () => { version: "1.0", }, { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, }, ); @@ -105,10 +107,12 @@ test("should accept supported older protocol version", async () => { version: "1.0", }, { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, }, ); @@ -163,10 +167,12 @@ test("should handle unsupported protocol version", async () => { version: "1.0", }, { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, }, ); @@ -197,19 +203,25 @@ test("should respect client capabilities", async () => { version: "1.0", }, { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + enforceStrictCapabilities: true, }, ); + const client = new Client( { name: "test client", version: "1.0", }, { - sampling: {}, + capabilities: { + sampling: {}, + }, }, ); @@ -245,9 +257,7 @@ test("should respect client capabilities", async () => { ).resolves.not.toThrow(); // This should still throw because roots are not supported by the client - await expect(server.listRoots()).rejects.toThrow( - "Client does not support roots", - ); + await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); }); /* @@ -301,10 +311,12 @@ test("should typecheck", () => { version: "1.0.0", }, { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, }, ); diff --git a/src/server/index.ts b/src/server/index.ts index c8be405..4558b34 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,4 +1,8 @@ -import { ProgressCallback, Protocol } from "../shared/protocol.js"; +import { + ProgressCallback, + Protocol, + ProtocolOptions, +} from "../shared/protocol.js"; import { ClientCapabilities, CreateMessageRequest, @@ -24,6 +28,13 @@ import { SUPPORTED_PROTOCOL_VERSIONS, } from "../types.js"; +export type ServerOptions = ProtocolOptions & { + /** + * Capabilities to advertise as being supported by this server. + */ + capabilities: ClientCapabilities; +}; + /** * An MCP server on top of a pluggable transport. * @@ -60,6 +71,7 @@ export class Server< > { private _clientCapabilities?: ClientCapabilities; private _clientVersion?: Implementation; + private _capabilities: ServerCapabilities; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). @@ -71,9 +83,10 @@ export class Server< */ constructor( private _serverInfo: Implementation, - private _capabilities: ServerCapabilities, + options: ServerOptions, ) { - super(); + super(options); + this._capabilities = options.capabilities; this.setRequestHandler(InitializeRequestSchema, (request) => this._oninitialize(request), @@ -83,6 +96,30 @@ export class Server< ); } + protected assertCapabilityForMethod(method: RequestT["method"]): void { + switch (method as ServerRequest["method"]) { + case "sampling/createMessage": + if (!this._clientCapabilities?.sampling) { + throw new Error( + `Client does not support sampling (required for ${method})`, + ); + } + break; + + case "roots/list": + if (!this._clientCapabilities?.roots) { + throw new Error( + `Client does not support listing roots (required for ${method})`, + ); + } + break; + + case "ping": + // No specific capability required for ping + break; + } + } + private async _oninitialize( request: InitializeRequest, ): Promise { @@ -137,7 +174,6 @@ export class Server< params: CreateMessageRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("sampling", "sampling/createMessage"); return this.request( { method: "sampling/createMessage", params }, CreateMessageResultSchema, @@ -149,7 +185,6 @@ export class Server< params?: ListRootsRequest["params"], onprogress?: ProgressCallback, ) { - this.assertCapability("roots", "roots/list"); return this.request( { method: "roots/list", params }, ListRootsResultSchema, diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 22d2503..ae30541 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -21,11 +21,23 @@ import { Transport } from "./transport.js"; */ export type ProgressCallback = (progress: Progress) => void; +/** + * Additional initialization options. + */ +export type ProtocolOptions = { + /** + * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. + * + * Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to true. + */ + enforceStrictCapabilities?: boolean; +}; + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ -export class Protocol< +export abstract class Protocol< SendRequestT extends Request, SendNotificationT extends Notification, SendResultT extends Result, @@ -70,7 +82,7 @@ export class Protocol< */ fallbackNotificationHandler?: (notification: Notification) => Promise; - constructor() { + constructor(private _options?: ProtocolOptions) { this.setNotificationHandler(ProgressNotificationSchema, (notification) => { this._onprogress(notification as unknown as ProgressNotification); }); @@ -245,6 +257,15 @@ export class Protocol< await this._transport?.close(); } + /** + * A method to check if a capability is supported by the remote side, for the given method to be called. + * + * This should be implemented by subclasses. + */ + protected abstract assertCapabilityForMethod( + method: SendRequestT["method"], + ): void; + /** * Sends a request and wait for a response, with optional progress notifications in the meantime (if supported by the server). * @@ -261,6 +282,10 @@ export class Protocol< return; } + if (this._options?.enforceStrictCapabilities === true) { + this.assertCapabilityForMethod(request.method); + } + const messageId = this._requestMessageId++; const jsonrpcRequest: JSONRPCRequest = { ...request, diff --git a/src/types.ts b/src/types.ts index 0ba6fa2..a0d2d80 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,7 +1,10 @@ import { z } from "zod"; export const LATEST_PROTOCOL_VERSION = "2024-11-05"; -export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, "2024-10-07"]; +export const SUPPORTED_PROTOCOL_VERSIONS = [ + LATEST_PROTOCOL_VERSION, + "2024-10-07", +]; /* JSON-RPC types */ export const JSONRPC_VERSION = "2.0"; @@ -179,7 +182,7 @@ export const ClientCapabilitiesSchema = z z .object({ /** - * Whether the client supports notifications for changes to the roots list. + * Whether the client supports issuing notifications for changes to the roots list. */ listChanged: z.optional(z.boolean()), }) @@ -223,7 +226,7 @@ export const ServerCapabilitiesSchema = z z .object({ /** - * Whether this server supports notifications for changes to the prompt list. + * Whether this server supports issuing notifications for changes to the prompt list. */ listChanged: z.optional(z.boolean()), }) @@ -236,11 +239,12 @@ export const ServerCapabilitiesSchema = z z .object({ /** - * Whether this server supports subscribing to resource updates. + * Whether this server supports clients subscribing to resource updates. */ subscribe: z.optional(z.boolean()), + /** - * Whether this server supports notifications for changes to the resource list. + * Whether this server supports issuing notifications for changes to the resource list. */ listChanged: z.optional(z.boolean()), }) @@ -253,7 +257,7 @@ export const ServerCapabilitiesSchema = z z .object({ /** - * Whether this server supports notifications for changes to the tool list. + * Whether this server supports issuing notifications for changes to the tool list. */ listChanged: z.optional(z.boolean()), }) @@ -725,9 +729,11 @@ export const CallToolResultSchema = ResultSchema.extend({ /** * CallToolResultSchema extended with backwards compatibility to protocol version 2024-10-07. */ -export const CompatibilityCallToolResultSchema = CallToolResultSchema.or(ResultSchema.extend({ - toolResult: z.unknown(), -})); +export const CompatibilityCallToolResultSchema = CallToolResultSchema.or( + ResultSchema.extend({ + toolResult: z.unknown(), + }), +); /** * Used by the client to invoke a tool provided by the server. @@ -802,12 +808,14 @@ export const LoggingMessageNotificationSchema = NotificationSchema.extend({ /** * Hints to use for model selection. */ -export const ModelHintSchema = z.object({ - /** - * A hint for a model name. - */ - name: z.string().optional(), -}).passthrough(); +export const ModelHintSchema = z + .object({ + /** + * A hint for a model name. + */ + name: z.string().optional(), + }) + .passthrough(); /** * The server's preferences for model selection, requested of the client during sampling. @@ -886,7 +894,9 @@ export const CreateMessageResultSchema = ResultSchema.extend({ /** * The reason why sampling stopped. */ - stopReason: z.optional(z.enum(["endTurn", "stopSequence", "maxTokens"]).or(z.string())), + stopReason: z.optional( + z.enum(["endTurn", "stopSequence", "maxTokens"]).or(z.string()), + ), role: z.enum(["user", "assistant"]), content: z.discriminatedUnion("type", [ TextContentSchema, @@ -1156,7 +1166,9 @@ export type Tool = z.infer; export type ListToolsRequest = z.infer; export type ListToolsResult = z.infer; export type CallToolResult = z.infer; -export type CompatibilityCallToolResult = z.infer; +export type CompatibilityCallToolResult = z.infer< + typeof CompatibilityCallToolResultSchema +>; export type CallToolRequest = z.infer; export type ToolListChangedNotification = z.infer< typeof ToolListChangedNotificationSchema From 64c28628839cc961d6940dfd7da03f16d6c34a6b Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 12 Nov 2024 12:02:10 +0000 Subject: [PATCH 09/11] Check and add tests for notification capabilities too --- src/client/index.test.ts | 102 +++++++++++++++++++++++++++++++++++++++ src/client/index.ts | 22 +++++++++ src/server/index.test.ts | 33 +++++++++++++ src/server/index.ts | 56 ++++++++++++++++----- src/shared/protocol.ts | 13 +++++ 5 files changed, 214 insertions(+), 12 deletions(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 7cbafd1..8717430 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -223,6 +223,108 @@ test("should respect server capabilities", async () => { ); }); +test("should respect client notification capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: {}, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + roots: { + listChanged: true, + }, + }, + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // This should work because the client has the roots.listChanged capability + await expect(client.sendRootsListChanged()).resolves.not.toThrow(); + + // Create a new client without the roots.listChanged capability + const clientWithoutCapability = new Client( + { + name: "test client without capability", + version: "1.0", + }, + { + capabilities: {}, + enforceStrictCapabilities: true, + }, + ); + + await clientWithoutCapability.connect(clientTransport); + + // This should throw because the client doesn't have the roots.listChanged capability + await expect(clientWithoutCapability.sendRootsListChanged()).rejects.toThrow( + /^Client does not support/, + ); +}); + +test("should respect server notification capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + logging: {}, + resources: { + listChanged: true, + }, + }, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: {}, + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // These should work because the server has the corresponding capabilities + await expect( + server.sendLoggingMessage({ level: "info", data: "Test" }), + ).resolves.not.toThrow(); + await expect(server.sendResourceListChanged()).resolves.not.toThrow(); + + // This should throw because the server doesn't have the tools capability + await expect(server.sendToolListChanged()).rejects.toThrow( + "Server does not support notifying of tool list changes", + ); +}); + /* Test that custom request/notification/result schemas can be used with the Client class. */ diff --git a/src/client/index.ts b/src/client/index.ts index 402c313..313c121 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -226,6 +226,28 @@ export class Client< } } + protected assertNotificationCapability( + method: NotificationT["method"], + ): void { + switch (method as ClientNotification["method"]) { + case "notifications/roots/list_changed": + if (!this._capabilities.roots?.listChanged) { + throw new Error( + "Client does not support roots list changed notifications", + ); + } + break; + + case "notifications/initialized": + // No specific capability required for initialized + break; + + case "notifications/progress": + // Progress notifications are always allowed + break; + } + } + async ping() { return this.request({ method: "ping" }, EmptyResultSchema); } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 274bb07..1d9e019 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -260,6 +260,39 @@ test("should respect client capabilities", async () => { await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); }); +test("should respect server notification capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + logging: {}, + }, + enforceStrictCapabilities: true, + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await server.connect(serverTransport); + + // This should work because logging is supported by the server + await expect( + server.sendLoggingMessage({ + level: "info", + data: "Test log message", + }), + ).resolves.not.toThrow(); + + // This should throw because resource notificaitons are not supported by the server + await expect( + server.sendResourceUpdated({ uri: "test://resource" }), + ).rejects.toThrow(/^Server does not support/); +}); + /* Test that custom request/notification/result schemas can be used with the Server class. */ diff --git a/src/server/index.ts b/src/server/index.ts index 4558b34..5d4e514 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -32,7 +32,7 @@ export type ServerOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this server. */ - capabilities: ClientCapabilities; + capabilities: ServerCapabilities; }; /** @@ -120,6 +120,49 @@ export class Server< } } + protected assertNotificationCapability( + method: (ServerNotification | NotificationT)["method"], + ): void { + switch (method as ServerNotification["method"]) { + case "notifications/message": + if (!this._capabilities.logging) { + throw new Error( + `Server does not support logging (required for ${method})`, + ); + } + break; + + case "notifications/resources/updated": + case "notifications/resources/list_changed": + if (!this._capabilities.resources) { + throw new Error( + `Server does not support notifying about resources (required for ${method})`, + ); + } + break; + + case "notifications/tools/list_changed": + if (!this._capabilities.tools) { + throw new Error( + `Server does not support notifying of tool list changes (required for ${method})`, + ); + } + break; + + case "notifications/prompts/list_changed": + if (!this._capabilities.prompts) { + throw new Error( + `Server does not support notifying of prompt list changes (required for ${method})`, + ); + } + break; + + case "notifications/progress": + // Progress notifications are always allowed + break; + } + } + private async _oninitialize( request: InitializeRequest, ): Promise { @@ -155,17 +198,6 @@ export class Server< return this._capabilities; } - private assertCapability( - capability: keyof ClientCapabilities, - method: string, - ) { - if (!this._clientCapabilities?.[capability]) { - throw new Error( - `Client does not support ${capability} (required for ${method})`, - ); - } - } - async ping() { return this.request({ method: "ping" }, EmptyResultSchema); } diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index ae30541..8b1705b 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -28,6 +28,8 @@ export type ProtocolOptions = { /** * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. * + * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. + * * Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to true. */ enforceStrictCapabilities?: boolean; @@ -266,6 +268,15 @@ export abstract class Protocol< method: SendRequestT["method"], ): void; + /** + * A method to check if a notification is supported by the local side, for the given method to be sent. + * + * This should be implemented by subclasses. + */ + protected abstract assertNotificationCapability( + method: SendNotificationT["method"], + ): void; + /** * Sends a request and wait for a response, with optional progress notifications in the meantime (if supported by the server). * @@ -326,6 +337,8 @@ export abstract class Protocol< throw new Error("Not connected"); } + this.assertNotificationCapability(notification.method); + const jsonrpcNotification: JSONRPCNotification = { ...notification, jsonrpc: "2.0", From 8b307a5d62e9b0a27950c5e6d81130fca1ecf0cd Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 12 Nov 2024 12:16:43 +0000 Subject: [PATCH 10/11] Check capabilities when request handlers are set --- src/client/index.test.ts | 33 +++++++++++++++++++++++++ src/client/index.ts | 20 +++++++++++++++ src/server/index.test.ts | 39 +++++++++++++++++++++++++++++ src/server/index.ts | 53 ++++++++++++++++++++++++++++++++++++++++ src/shared/protocol.ts | 11 ++++++++- 5 files changed, 155 insertions(+), 1 deletion(-) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 8717430..5610a62 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -12,6 +12,8 @@ import { InitializeRequestSchema, ListResourcesRequestSchema, ListToolsRequestSchema, + CreateMessageRequestSchema, + ListRootsRequestSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { Server } from "../server/index.js"; @@ -325,6 +327,37 @@ test("should respect server notification capabilities", async () => { ); }); +test("should only allow setRequestHandler for declared capabilities", () => { + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + // This should work because sampling is a declared capability + expect(() => { + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + })); + }).not.toThrow(); + + // This should throw because roots listing is not a declared capability + expect(() => { + client.setRequestHandler(ListRootsRequestSchema, () => ({})); + }).toThrow("Client does not support roots capability"); +}); + /* Test that custom request/notification/result schemas can be used with the Client class. */ diff --git a/src/client/index.ts b/src/client/index.ts index 313c121..dd5d2ce 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -248,6 +248,26 @@ export class Client< } } + protected assertRequestHandlerCapability(method: string): void { + switch (method) { + case "sampling/createMessage": + if (!this._capabilities.sampling) { + throw new Error("Client does not support sampling capability"); + } + break; + + case "roots/list": + if (!this._capabilities.roots) { + throw new Error("Client does not support roots capability"); + } + break; + + case "ping": + // No specific capability required for ping + break; + } + } + async ping() { return this.request({ method: "ping" }, EmptyResultSchema); } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 1d9e019..d30c670 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -10,6 +10,10 @@ import { LATEST_PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS, CreateMessageRequestSchema, + ListPromptsRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + SetLevelRequestSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; @@ -293,6 +297,41 @@ test("should respect server notification capabilities", async () => { ).rejects.toThrow(/^Server does not support/); }); +test("should only allow setRequestHandler for declared capabilities", () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + }, + }, + ); + + // These should work because the capabilities are declared + expect(() => { + server.setRequestHandler(ListPromptsRequestSchema, () => ({ prompts: [] })); + }).not.toThrow(); + + expect(() => { + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [], + })); + }).not.toThrow(); + + // These should throw because the capabilities are not declared + expect(() => { + server.setRequestHandler(ListToolsRequestSchema, () => ({ tools: [] })); + }).toThrow(/^Server does not support tools/); + + expect(() => { + server.setRequestHandler(SetLevelRequestSchema, () => ({})); + }).toThrow(/^Server does not support logging/); +}); + /* Test that custom request/notification/result schemas can be used with the Server class. */ diff --git a/src/server/index.ts b/src/server/index.ts index 5d4e514..ecb525b 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -163,6 +163,59 @@ export class Server< } } + protected assertRequestHandlerCapability(method: string): void { + switch (method) { + case "sampling/createMessage": + if (!this._capabilities.sampling) { + throw new Error( + `Server does not support sampling (required for ${method})`, + ); + } + break; + + case "logging/setLevel": + if (!this._capabilities.logging) { + throw new Error( + `Server does not support logging (required for ${method})`, + ); + } + break; + + case "prompts/get": + case "prompts/list": + if (!this._capabilities.prompts) { + throw new Error( + `Server does not support prompts (required for ${method})`, + ); + } + break; + + case "resources/list": + case "resources/templates/list": + case "resources/read": + if (!this._capabilities.resources) { + throw new Error( + `Server does not support resources (required for ${method})`, + ); + } + break; + + case "tools/call": + case "tools/list": + if (!this._capabilities.tools) { + throw new Error( + `Server does not support tools (required for ${method})`, + ); + } + break; + + case "ping": + case "initialize": + // No specific capability required for these methods + break; + } + } + private async _oninitialize( request: InitializeRequest, ): Promise { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 8b1705b..85610a9 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -277,6 +277,13 @@ export abstract class Protocol< method: SendNotificationT["method"], ): void; + /** + * A method to check if a request handler is supported by the local side, for the given method to be handled. + * + * This should be implemented by subclasses. + */ + protected abstract assertRequestHandlerCapability(method: string): void; + /** * Sends a request and wait for a response, with optional progress notifications in the meantime (if supported by the server). * @@ -360,7 +367,9 @@ export abstract class Protocol< requestSchema: T, handler: (request: z.infer) => SendResultT | Promise, ): void { - this._requestHandlers.set(requestSchema.shape.method.value, (request) => + const method = requestSchema.shape.method.value; + this.assertRequestHandlerCapability(method); + this._requestHandlers.set(method, (request) => Promise.resolve(handler(requestSchema.parse(request))), ); } From 70cfb0f0d458068af7721f8428f0531021143280 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 12 Nov 2024 12:17:35 +0000 Subject: [PATCH 11/11] Unify error text --- src/client/index.ts | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index dd5d2ce..e0df322 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -165,7 +165,7 @@ export class Client< case "logging/setLevel": if (!this._serverCapabilities?.logging) { throw new Error( - "Server does not support logging (required for logging/setLevel)", + `Server does not support logging (required for ${method})`, ); } break; @@ -194,7 +194,9 @@ export class Client< method === "resources/subscribe" && !this._serverCapabilities.resources.subscribe ) { - throw new Error("Server does not support resource subscriptions"); + throw new Error( + `Server does not support resource subscriptions (required for ${method})`, + ); } break; @@ -211,7 +213,7 @@ export class Client< case "completion/complete": if (!this._serverCapabilities?.prompts) { throw new Error( - "Server does not support prompts (required for completion/complete)", + `Server does not support prompts (required for ${method})`, ); } break; @@ -233,7 +235,7 @@ export class Client< case "notifications/roots/list_changed": if (!this._capabilities.roots?.listChanged) { throw new Error( - "Client does not support roots list changed notifications", + `Client does not support roots list changed notifications (required for ${method})`, ); } break; @@ -252,13 +254,17 @@ export class Client< switch (method) { case "sampling/createMessage": if (!this._capabilities.sampling) { - throw new Error("Client does not support sampling capability"); + throw new Error( + `Client does not support sampling capability (required for ${method})`, + ); } break; case "roots/list": if (!this._capabilities.roots) { - throw new Error("Client does not support roots capability"); + throw new Error( + `Client does not support roots capability (required for ${method})`, + ); } break;