Skip to content
This repository was archived by the owner on May 29, 2025. It is now read-only.

Commit 4cbe50a

Browse files
gsabranGui Sabran
andauthored
Add support for client capabilities (#5)
* nit: rename * move files * pass capability in initialization * handle server requests * make API more robust * lint --------- Co-authored-by: Gui Sabran <gsabran@www.com>
1 parent ef2617b commit 4cbe50a

29 files changed

+351
-129
lines changed

MCPClient/Sources/MCPClient.swift

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import Combine
33
import Foundation
44
import MCPShared
55

6+
public typealias SamplingRequestHandler = ((CreateMessageRequest.Params) async throws -> CreateMessageRequest.Result)
7+
public typealias ListRootsRequestHandler = ((ListRootsRequest.Params?) async throws -> ListRootsRequest.Result)
8+
69
// MARK: - MCPClient
710

811
public actor MCPClient: MCPClientInterface {
@@ -11,26 +14,30 @@ public actor MCPClient: MCPClientInterface {
1114

1215
public init(
1316
info: Implementation,
14-
capabilities: ClientCapabilities,
15-
transport: Transport)
17+
transport: Transport,
18+
capabilities: ClientCapabilityHandlers = .init())
1619
async throws {
1720
try await self.init(
18-
info: info,
19-
capabilities: capabilities,
20-
getMcpConnection: { try MCPConnection(
21+
samplingRequestHandler: capabilities.sampling?.handler,
22+
listRootRequestHandler: capabilities.roots?.handler,
23+
connection: try MCPClientConnection(
2124
info: info,
22-
capabilities: capabilities,
23-
transport: transport) })
25+
capabilities: ClientCapabilities(
26+
experimental: nil, // TODO: support experimental requests
27+
roots: capabilities.roots?.info,
28+
sampling: capabilities.sampling?.info),
29+
transport: transport))
2430
}
2531

2632
init(
27-
info _: Implementation,
28-
capabilities _: ClientCapabilities,
29-
getMcpConnection: @escaping () throws -> MCPConnectionInterface)
33+
samplingRequestHandler: SamplingRequestHandler? = nil,
34+
listRootRequestHandler: ListRootsRequestHandler? = nil,
35+
connection: MCPClientConnectionInterface)
3036
async throws {
31-
self.getMcpConnection = getMcpConnection
32-
3337
// Initialize the connection, and then update server capabilities.
38+
self.connection = connection
39+
self.samplingRequestHandler = samplingRequestHandler
40+
self.listRootRequestHandler = listRootRequestHandler
3441
try await connect()
3542
Task { try await self.updateTools() }
3643
Task { try await self.updatePrompts() }
@@ -111,23 +118,28 @@ public actor MCPClient: MCPClientInterface {
111118
return try await connectionInfo.connection.readResource(.init(uri: uri))
112119
}
113120

121+
// MARK: Internal
122+
123+
let connection: MCPClientConnectionInterface
124+
114125
// MARK: Private
115126

116127
private struct ConnectionInfo {
117-
let connection: MCPConnectionInterface
128+
let connection: MCPClientConnectionInterface
118129
let serverInfo: Implementation
119130
let serverCapabilities: ServerCapabilities
120131
}
121132

133+
private let samplingRequestHandler: SamplingRequestHandler?
134+
private let listRootRequestHandler: ListRootsRequestHandler?
135+
122136
private var connectionInfo: ConnectionInfo?
123137

124138
private let _tools = CurrentValueSubject<ServerCapabilityState<[Tool]>?, Never>(nil)
125139
private let _prompts = CurrentValueSubject<ServerCapabilityState<[Prompt]>?, Never>(nil)
126140
private let _resources = CurrentValueSubject<ServerCapabilityState<[Resource]>?, Never>(nil)
127141
private let _resourceTemplates = CurrentValueSubject<ServerCapabilityState<[ResourceTemplate]>?, Never>(nil)
128142

129-
private let getMcpConnection: () throws -> MCPConnectionInterface
130-
131143
private var progressHandlers = [String: (progress: Double, total: Double?) -> Void]()
132144

133145
private func startListeningToNotifications() async throws {
@@ -163,6 +175,52 @@ public actor MCPClient: MCPClientInterface {
163175
}
164176
}
165177

178+
private func startListeningToRequests() async throws {
179+
let connectionInfo = try getConnectionInfo()
180+
let requests = await connectionInfo.connection.requestsToHandle
181+
Task { [weak self] in
182+
for await(request, completion) in requests {
183+
guard let self else {
184+
completion(.failure(.init(
185+
code: JRPCErrorCodes.internalError.rawValue,
186+
message: "The client disconnected")))
187+
return
188+
}
189+
switch request {
190+
case .createMessage(let params):
191+
if let handler = await self.samplingRequestHandler {
192+
do {
193+
completion(.success(try await handler(params)))
194+
} catch {
195+
completion(.failure(.init(
196+
code: JRPCErrorCodes.internalError.rawValue,
197+
message: error.localizedDescription)))
198+
}
199+
} else {
200+
completion(.failure(.init(
201+
code: JRPCErrorCodes.invalidRequest.rawValue,
202+
message: "Sampling is not supported by this client")))
203+
}
204+
205+
case .listRoots(let params):
206+
if let handler = await self.listRootRequestHandler {
207+
do {
208+
completion(.success(try await handler(params)))
209+
} catch {
210+
completion(.failure(.init(
211+
code: JRPCErrorCodes.internalError.rawValue,
212+
message: error.localizedDescription)))
213+
}
214+
} else {
215+
completion(.failure(.init(
216+
code: JRPCErrorCodes.invalidRequest.rawValue,
217+
message: "Listing roots is not supported by this client")))
218+
}
219+
}
220+
}
221+
}
222+
}
223+
166224
private func startPinging() {
167225
// TODO
168226
}
@@ -212,19 +270,19 @@ public actor MCPClient: MCPClientInterface {
212270
}
213271

214272
private func connect() async throws {
215-
let mcpConnection = try getMcpConnection()
216-
let response = try await mcpConnection.initialize()
273+
let response = try await connection.initialize()
217274
guard response.protocolVersion == MCP.protocolVersion else {
218275
throw MCPClientError.versionMismatch
219276
}
220277

221278
connectionInfo = ConnectionInfo(
222-
connection: mcpConnection,
279+
connection: connection,
223280
serverInfo: response.serverInfo,
224281
serverCapabilities: response.capabilities)
225282

226-
try await mcpConnection.acknowledgeInitialization()
283+
try await connection.acknowledgeInitialization()
227284
try await startListeningToNotifications()
285+
try await startListeningToRequests()
228286
startPinging()
229287
}
230288

MCPClient/Sources/MCPConnection.swift renamed to MCPClient/Sources/MCPClientConnection.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ import OSLog
66

77
private let mcpLogger = Logger(subsystem: Bundle.main.bundleIdentifier.map { "\($0).mcp" } ?? "com.app.mcp", category: "mcp")
88

9-
// MARK: - MCPConnection
9+
// MARK: - MCPClientConnection
1010

11-
public actor MCPConnection: MCPConnectionInterface {
11+
public actor MCPClientConnection: MCPClientConnectionInterface {
1212

1313
// MARK: Lifecycle
1414

MCPClient/Sources/MCPConnectionInterface.swift renamed to MCPClient/Sources/MCPClientConnectionInterface.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ public typealias AnyJRPCResponse = Swift.Result<Encodable & Sendable, AnyJSONRPC
66

77
public typealias HandleServerRequest = (ServerRequest, (AnyJRPCResponse) -> Void)
88

9-
// MARK: - MCPConnectionInterface
9+
// MARK: - MCPClientConnectionInterface
1010

1111
/// The MCP JRPC Bridge is a stateless interface to the MCP server that provides a higher level Swift interface.
1212
/// It does not implement any of the stateful behaviors of the MCP server, such as subscribing to changes, detecting connection health,
1313
/// ensuring that the connection has been initialized before being used etc.
1414
///
1515
/// For most use cases, `MCPClient` should be a preferred interface.
16-
public protocol MCPConnectionInterface {
16+
public protocol MCPClientConnectionInterface {
1717
/// The notifications received by the server.
1818
var notifications: AsyncStream<ServerNotification> { get async }
1919
// TODO: look at moving the request handler to the init

MCPClient/Sources/MCPClientInterface.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
import JSONRPC
22
import MCPShared
3+
import MemberwiseInit
34

45
// MARK: - MCPClientInterface
56

67
public protocol MCPClientInterface { }
78

89
public typealias Transport = DataChannel
910

11+
// MARK: - ClientCapabilityHandlers
12+
13+
/// Describes the supported capabilities of an MCP client, and how to handle each of the supported ones.
14+
///
15+
/// Note: This is similar to `ClientCapabilities`, with the addition of the handler function.
16+
@MemberwiseInit(.public, _optionalsDefaultNil: true)
17+
public struct ClientCapabilityHandlers {
18+
public let roots: CapabilityHandler<ListChangedCapability, ListRootsRequestHandler>?
19+
public let sampling: CapabilityHandler<EmptyObject, SamplingRequestHandler>?
20+
// TODO: add experimental
21+
}
22+
1023
// MARK: - MCPClientError
1124

1225
public enum MCPClientError: Error {

MCPClient/Sources/MockMCPConnection.swift

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import MCPShared
44
#if DEBUG
55
// TODO: move to a test helper package
66

7-
/// A mock `MCPConnection` that can be used in tests.
8-
class MockMCPConnection: MCPConnectionInterface {
7+
/// A mock `MCPClientConnection` that can be used in tests.
8+
class MockMCPClientConnection: MCPClientConnectionInterface {
99

1010
// MARK: Lifecycle
1111

@@ -89,109 +89,109 @@ class MockMCPConnection: MCPConnectionInterface {
8989
if let initializeStub {
9090
return try await initializeStub()
9191
}
92-
throw MockMCPConnectionError.notImplemented(function: "initialize")
92+
throw MockMCPClientConnectionError.notImplemented(function: "initialize")
9393
}
9494

9595
func acknowledgeInitialization() async throws {
9696
if let acknowledgeInitializationStub {
9797
return try await acknowledgeInitializationStub()
9898
}
99-
throw MockMCPConnectionError.notImplemented(function: "acknowledgeInitialization")
99+
throw MockMCPClientConnectionError.notImplemented(function: "acknowledgeInitialization")
100100
}
101101

102102
func ping() async throws {
103103
if let pingStub {
104104
return try await pingStub()
105105
}
106-
throw MockMCPConnectionError.notImplemented(function: "ping")
106+
throw MockMCPClientConnectionError.notImplemented(function: "ping")
107107
}
108108

109109
func listPrompts() async throws -> [Prompt] {
110110
if let listPromptsStub {
111111
return try await listPromptsStub()
112112
}
113-
throw MockMCPConnectionError.notImplemented(function: "listPrompts")
113+
throw MockMCPClientConnectionError.notImplemented(function: "listPrompts")
114114
}
115115

116116
func getPrompt(_ params: GetPromptRequest.Params) async throws -> GetPromptRequest.Result {
117117
if let getPromptStub {
118118
return try await getPromptStub(params)
119119
}
120-
throw MockMCPConnectionError.notImplemented(function: "getPrompt")
120+
throw MockMCPClientConnectionError.notImplemented(function: "getPrompt")
121121
}
122122

123123
func listResources() async throws -> [Resource] {
124124
if let listResourcesStub {
125125
return try await listResourcesStub()
126126
}
127-
throw MockMCPConnectionError.notImplemented(function: "listResources")
127+
throw MockMCPClientConnectionError.notImplemented(function: "listResources")
128128
}
129129

130130
func readResource(_ params: ReadResourceRequest.Params) async throws -> ReadResourceRequest.Result {
131131
if let readResourceStub {
132132
return try await readResourceStub(params)
133133
}
134-
throw MockMCPConnectionError.notImplemented(function: "readResource")
134+
throw MockMCPClientConnectionError.notImplemented(function: "readResource")
135135
}
136136

137137
func subscribeToUpdateToResource(_ params: SubscribeRequest.Params) async throws {
138138
if let subscribeToUpdateToResourceStub {
139139
return try await subscribeToUpdateToResourceStub(params)
140140
}
141-
throw MockMCPConnectionError.notImplemented(function: "subscribeToUpdateToResource")
141+
throw MockMCPClientConnectionError.notImplemented(function: "subscribeToUpdateToResource")
142142
}
143143

144144
func unsubscribeToUpdateToResource(_ params: UnsubscribeRequest.Params) async throws {
145145
if let unsubscribeToUpdateToResourceStub {
146146
return try await unsubscribeToUpdateToResourceStub(params)
147147
}
148-
throw MockMCPConnectionError.notImplemented(function: "unsubscribeToUpdateToResource")
148+
throw MockMCPClientConnectionError.notImplemented(function: "unsubscribeToUpdateToResource")
149149
}
150150

151151
func listResourceTemplates() async throws -> [ResourceTemplate] {
152152
if let listResourceTemplatesStub {
153153
return try await listResourceTemplatesStub()
154154
}
155-
throw MockMCPConnectionError.notImplemented(function: "listResourceTemplates")
155+
throw MockMCPClientConnectionError.notImplemented(function: "listResourceTemplates")
156156
}
157157

158158
func listTools() async throws -> [Tool] {
159159
if let listToolsStub {
160160
return try await listToolsStub()
161161
}
162-
throw MockMCPConnectionError.notImplemented(function: "listTools")
162+
throw MockMCPClientConnectionError.notImplemented(function: "listTools")
163163
}
164164

165165
func call(toolName: String, arguments: JSON?, progressToken: ProgressToken?) async throws -> CallToolRequest.Result {
166166
if let callToolStub {
167167
return try await callToolStub(toolName, arguments, progressToken)
168168
}
169-
throw MockMCPConnectionError.notImplemented(function: "callTool")
169+
throw MockMCPClientConnectionError.notImplemented(function: "callTool")
170170
}
171171

172172
func requestCompletion(_ params: CompleteRequest.Params) async throws -> CompleteRequest.Result {
173173
if let requestCompletionStub {
174174
return try await requestCompletionStub(params)
175175
}
176-
throw MockMCPConnectionError.notImplemented(function: "requestCompletion")
176+
throw MockMCPClientConnectionError.notImplemented(function: "requestCompletion")
177177
}
178178

179179
func setLogLevel(_ params: SetLevelRequest.Params) async throws -> SetLevelRequest.Result {
180180
if let setLogLevelStub {
181181
return try await setLogLevelStub(params)
182182
}
183-
throw MockMCPConnectionError.notImplemented(function: "setLogLevel")
183+
throw MockMCPClientConnectionError.notImplemented(function: "setLogLevel")
184184
}
185185

186186
func log(_ params: LoggingMessageNotification.Params) async throws {
187187
if let logStub {
188188
return try await logStub(params)
189189
}
190-
throw MockMCPConnectionError.notImplemented(function: "log")
190+
throw MockMCPClientConnectionError.notImplemented(function: "log")
191191
}
192192
}
193193

194-
enum MockMCPConnectionError: Error {
194+
enum MockMCPClientConnectionError: Error {
195195
case notImplemented(function: String)
196196
}
197197

0 commit comments

Comments
 (0)