Skip to content

Commit

Permalink
Handle capabilities checks at the request() level, make non-strict by…
Browse files Browse the repository at this point in the history
… default
  • Loading branch information
jspahrsummers committed Nov 12, 2024
1 parent 8e369b7 commit dc45f5d
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 83 deletions.
18 changes: 12 additions & 6 deletions src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ async function runClient(url_or_command: string, args: string[]) {
version: "0.1.0",
},
{
sampling: {},
capabilities: {
sampling: {},
},
},
);

Expand Down Expand Up @@ -73,7 +75,9 @@ async function runServer(port: number | null) {
name: "mcp-typescript test server",
version: "0.1.0",
},
{},
{
capabilities: {},
},
);

servers.push(server);
Expand Down Expand Up @@ -111,10 +115,12 @@ async function runServer(port: number | null) {
version: "0.1.0",
},
{
prompts: {},
resources: {},
tools: {},
logging: {},
capabilities: {
prompts: {},
resources: {},
tools: {},
logging: {},
},
},
);

Expand Down
27 changes: 20 additions & 7 deletions src/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ test("should initialize with matching protocol version", async () => {
version: "1.0",
},
{
sampling: {},
capabilities: {
sampling: {},
},
},
);

Expand Down Expand Up @@ -93,7 +95,9 @@ test("should initialize with supported older protocol version", async () => {
version: "1.0",
},
{
sampling: {},
capabilities: {
sampling: {},
},
},
);

Expand Down Expand Up @@ -135,7 +139,9 @@ test("should reject unsupported protocol version", async () => {
version: "1.0",
},
{
sampling: {},
capabilities: {
sampling: {},
},
},
);

Expand All @@ -153,8 +159,10 @@ test("should respect server capabilities", async () => {
version: "1.0",
},
{
resources: {},
tools: {},
capabilities: {
resources: {},
tools: {},
},
},
);

Expand Down Expand Up @@ -187,7 +195,10 @@ test("should respect server capabilities", async () => {
version: "1.0",
},
{
sampling: {},
capabilities: {
sampling: {},
},
enforceStrictCapabilities: true,
},
);

Expand Down Expand Up @@ -263,7 +274,9 @@ test("should typecheck", () => {
version: "1.0.0",
},
{
sampling: {},
capabilities: {
sampling: {},
},
},
);

Expand Down
112 changes: 90 additions & 22 deletions src/client/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import { ProgressCallback, Protocol } from "../shared/protocol.js";
import {
ProgressCallback,
Protocol,
ProtocolOptions,
} from "../shared/protocol.js";
import { Transport } from "../shared/transport.js";
import {
CallToolRequest,
Expand Down Expand Up @@ -36,6 +40,13 @@ import {
UnsubscribeRequest,
} from "../types.js";

export type ClientOptions = ProtocolOptions & {
/**
* Capabilities to advertise as being supported by this client.
*/
capabilities: ClientCapabilities;
};

/**
* An MCP client on top of a pluggable transport.
*
Expand Down Expand Up @@ -72,15 +83,28 @@ export class Client<
> {
private _serverCapabilities?: ServerCapabilities;
private _serverVersion?: Implementation;
private _capabilities: ClientCapabilities;

/**
* Initializes this client with the given name and version information.
*/
constructor(
private _clientInfo: Implementation,
private _capabilities: ClientCapabilities,
options: ClientOptions,
) {
super();
super(options);
this._capabilities = options.capabilities;
}

protected assertCapability(
capability: keyof ServerCapabilities,
method: string,
): void {
if (!this._serverCapabilities?.[capability]) {
throw new Error(
`Server does not support ${capability} (required for ${method})`,
);
}
}

override async connect(transport: Transport): Promise<void> {
Expand Down Expand Up @@ -136,14 +160,69 @@ export class Client<
return this._serverVersion;
}

private assertCapability(
capability: keyof ServerCapabilities,
method: string,
) {
if (!this._serverCapabilities?.[capability]) {
throw new Error(
`Server does not support ${capability} (required for ${method})`,
);
protected assertCapabilityForMethod(method: RequestT["method"]): void {
switch (method as ClientRequest["method"]) {
case "logging/setLevel":
if (!this._serverCapabilities?.logging) {
throw new Error(
"Server does not support logging (required for logging/setLevel)",
);
}
break;

case "prompts/get":
case "prompts/list":
if (!this._serverCapabilities?.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
);
}
break;

case "resources/list":
case "resources/templates/list":
case "resources/read":
case "resources/subscribe":
case "resources/unsubscribe":
if (!this._serverCapabilities?.resources) {
throw new Error(
`Server does not support resources (required for ${method})`,
);
}

if (
method === "resources/subscribe" &&
!this._serverCapabilities.resources.subscribe
) {
throw new Error("Server does not support resource subscriptions");
}

break;

case "tools/call":
case "tools/list":
if (!this._serverCapabilities?.tools) {
throw new Error(
`Server does not support tools (required for ${method})`,
);
}
break;

case "completion/complete":
if (!this._serverCapabilities?.prompts) {
throw new Error(
"Server does not support prompts (required for completion/complete)",
);
}
break;

case "initialize":
// No specific capability required for initialize
break;

case "ping":
// No specific capability required for ping
break;
}
}

Expand All @@ -155,7 +234,6 @@ export class Client<
params: CompleteRequest["params"],
onprogress?: ProgressCallback,
) {
this.assertCapability("prompts", "completion/complete");
return this.request(
{ method: "completion/complete", params },
CompleteResultSchema,
Expand All @@ -164,7 +242,6 @@ export class Client<
}

async setLoggingLevel(level: LoggingLevel) {
this.assertCapability("logging", "logging/setLevel");
return this.request(
{ method: "logging/setLevel", params: { level } },
EmptyResultSchema,
Expand All @@ -175,7 +252,6 @@ export class Client<
params: GetPromptRequest["params"],
onprogress?: ProgressCallback,
) {
this.assertCapability("prompts", "prompts/get");
return this.request(
{ method: "prompts/get", params },
GetPromptResultSchema,
Expand All @@ -187,7 +263,6 @@ export class Client<
params?: ListPromptsRequest["params"],
onprogress?: ProgressCallback,
) {
this.assertCapability("prompts", "prompts/list");
return this.request(
{ method: "prompts/list", params },
ListPromptsResultSchema,
Expand All @@ -199,7 +274,6 @@ export class Client<
params?: ListResourcesRequest["params"],
onprogress?: ProgressCallback,
) {
this.assertCapability("resources", "resources/list");
return this.request(
{ method: "resources/list", params },
ListResourcesResultSchema,
Expand All @@ -211,7 +285,6 @@ export class Client<
params?: ListResourceTemplatesRequest["params"],
onprogress?: ProgressCallback,
) {
this.assertCapability("resources", "resources/templates/list");
return this.request(
{ method: "resources/templates/list", params },
ListResourceTemplatesResultSchema,
Expand All @@ -223,7 +296,6 @@ export class Client<
params: ReadResourceRequest["params"],
onprogress?: ProgressCallback,
) {
this.assertCapability("resources", "resources/read");
return this.request(
{ method: "resources/read", params },
ReadResourceResultSchema,
Expand All @@ -232,15 +304,13 @@ export class Client<
}

async subscribeResource(params: SubscribeRequest["params"]) {
this.assertCapability("resources", "resources/subscribe");
return this.request(
{ method: "resources/subscribe", params },
EmptyResultSchema,
);
}

async unsubscribeResource(params: UnsubscribeRequest["params"]) {
this.assertCapability("resources", "resources/unsubscribe");
return this.request(
{ method: "resources/unsubscribe", params },
EmptyResultSchema,
Expand All @@ -254,7 +324,6 @@ export class Client<
| typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
onprogress?: ProgressCallback,
) {
this.assertCapability("tools", "tools/call");
return this.request(
{ method: "tools/call", params },
resultSchema,
Expand All @@ -266,7 +335,6 @@ export class Client<
params?: ListToolsRequest["params"],
onprogress?: ProgressCallback,
) {
this.assertCapability("tools", "tools/list");
return this.request(
{ method: "tools/list", params },
ListToolsResultSchema,
Expand Down
Loading

0 comments on commit dc45f5d

Please sign in to comment.