diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index 029808b..a7f3b5f 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -148,10 +148,11 @@ def handle_json_request(request, headers: {}) # Handle incoming JSON-RPC request def handle_request(json_str, headers: {}) # rubocop:disable Metrics/MethodLength + client_id = headers['client_id'] begin request = JSON.parse(json_str) rescue JSON::ParserError, TypeError - return send_error(-32_600, 'Invalid Request', nil) + return send_error(-32_600, client_id, 'Invalid Request', nil) end @logger.debug("Received request: #{request.inspect}") @@ -161,38 +162,38 @@ def handle_request(json_str, headers: {}) # rubocop:disable Metrics/MethodLength id = request['id'] # Check if it's a valid JSON-RPC 2.0 request - return send_error(-32_600, 'Invalid Request', id) unless request['jsonrpc'] == '2.0' + return send_error(-32_600, client_id, 'Invalid Request', id) unless request['jsonrpc'] == '2.0' case method when 'ping' - send_result({}, id) + send_result(client_id, {}, id) when 'initialize' - handle_initialize(params, id) + handle_initialize(params, headers, id) when 'notifications/initialized' handle_initialized_notification when 'tools/list' - handle_tools_list(id) + handle_tools_list(headers, id) when 'tools/call' handle_tools_call(params, headers, id) when 'resources/list' - handle_resources_list(id) + handle_resources_list(headers, id) when 'resources/templates/list' - handle_resources_templates_list(id) + handle_resources_templates_list(headers, id) when 'resources/read' - handle_resources_read(params, id) + handle_resources_read(params, headers, id) when 'resources/subscribe' - handle_resources_subscribe(params, id) + handle_resources_subscribe(params, headers, id) when 'resources/unsubscribe' - handle_resources_unsubscribe(params, id) + handle_resources_unsubscribe(params, headers, id) when nil # This is a notification response, we don't need to handle it nil else - send_error(-32_601, "Method not found: #{method}", id) + send_error(-32_601, client_id, "Method not found: #{method}", id) end rescue StandardError => e @logger.error("Error handling request: #{e.message}, #{e.backtrace.join("\n")}") - send_error(-32_600, "Internal error: #{e.message}, #{e.backtrace.join("\n")}", id) + send_error(-32_600, client_id, "Internal error: #{e.message}, #{e.backtrace.join("\n")}", id) end # Notify subscribers about a resource update @@ -222,10 +223,11 @@ def read_resource(uri) PROTOCOL_VERSION = '2024-11-05' - def handle_initialize(params, id) + def handle_initialize(params, headers, id) # Store client capabilities for later use @client_capabilities = params['capabilities'] || {} client_info = params['clientInfo'] || {} + client_id = headers['client_id'] # Log client information @logger.info("Client connected: #{client_info['name']} v#{client_info['version']}") @@ -243,20 +245,21 @@ def handle_initialize(params, id) @logger.info("Server response: #{response.inspect}") - send_result(response, id) + send_result(client_id, response, id) end # Handle a resource read - def handle_resources_read(params, id) + def handle_resources_read(params, headers, id) uri = params['uri'] + client_id = headers['client_id'] - return send_error(-32_602, 'Invalid params: missing resource URI', id) unless uri + return send_error(-32_602, client_id, 'Invalid params: missing resource URI', id) unless uri @logger.debug("Looking for resource with URI: #{uri}") begin resource = read_resource(uri) - return send_error(-32_602, "Resource not found: #{uri}", id) unless resource + return send_error(-32_602, client_id, "Resource not found: #{uri}", id) unless resource @logger.debug("Found resource: #{resource.resource_name}, templated: #{resource.templated?}") @@ -278,7 +281,7 @@ def handle_resources_read(params, id) # # rescue StandardError => e # @logger.error("Error reading resource: #{e.message}") # @logger.error(e.backtrace.join("\n")) - send_result(result, id) + send_result(client_id, result, id) end end @@ -292,7 +295,8 @@ def handle_initialized_notification end # Handle tools/list request - def handle_tools_list(id) + def handle_tools_list(headers, id) + client_id = headers['client_id'] tools_list = @tools.values.map do |tool| { name: tool.tool_name, @@ -301,18 +305,19 @@ def handle_tools_list(id) } end - send_result({ tools: tools_list }, id) + send_result(client_id, { tools: tools_list }, id) end # Handle tools/call request def handle_tools_call(params, headers, id) tool_name = params['name'] arguments = params['arguments'] || {} + client_id = headers['client_id'] - return send_error(-32_602, 'Invalid params: missing tool name', id) unless tool_name + return send_error(-32_602, client_id, 'Invalid params: missing tool name', id) unless tool_name tool = @tools[tool_name] - return send_error(-32_602, "Tool not found: #{tool_name}", id) unless tool + return send_error(-32_602, client_id, "Tool not found: #{tool_name}", id) unless tool begin # Convert string keys to symbols for Ruby @@ -321,26 +326,26 @@ def handle_tools_call(params, headers, id) tool_instance = tool.new(headers: headers) authorized = tool_instance.authorized?(**symbolized_args) - return send_error(-32_602, 'Unauthorized', id) unless authorized + return send_error(-32_602, client_id, 'Unauthorized', id) unless authorized result, metadata = tool_instance.call_with_schema_validation!(**symbolized_args) # Format and send the result - send_formatted_result(result, id, metadata) + send_formatted_result(client_id, result, id, metadata) rescue FastMcp::Tool::InvalidArgumentsError => e @logger.error("Invalid arguments for tool #{tool_name}: #{e.message}") - send_error_result(e.message, id) + send_error_result(client_id, e.message, id) rescue StandardError => e @logger.error("Error calling tool #{tool_name}: #{e.message}") - send_error_result("#{e.message}, #{e.backtrace.join("\n")}", id) + send_error_result(client_id, "#{e.message}, #{e.backtrace.join("\n")}", id) end end # Format and send successful result - def send_formatted_result(result, id, metadata) + def send_formatted_result(client_id, result, id, metadata) # Check if the result is already in the expected format if result.is_a?(Hash) && result.key?(:content) - send_result(result, id, metadata: metadata) + send_result(client_id, result, id, metadata: metadata) else # Format the result according to the MCP specification formatted_result = { @@ -348,65 +353,69 @@ def send_formatted_result(result, id, metadata) isError: false } - send_result(formatted_result, id, metadata: metadata) + send_result(client_id, formatted_result, id, metadata: metadata) end end # Format and send error result - def send_error_result(message, id) + def send_error_result(client_id, message, id) # Format error according to the MCP specification error_result = { content: [{ type: 'text', text: "Error: #{message}" }], isError: true } - send_result(error_result, id) + send_result(client_id, error_result, id) end # Handle resources/list request - def handle_resources_list(id) + def handle_resources_list(headers, id) + client_id = headers['client_id'] resources_list = @resources.select(&:non_templated?).map(&:metadata) - send_result({ resources: resources_list }, id) + send_result(client_id, { resources: resources_list }, id) end # Handle resources/templates/list request - def handle_resources_templates_list(id) + def handle_resources_templates_list(headers, id) + client_id = headers['client_id'] # Collect templated resources templated_resources_list = @resources.select(&:templated?).map(&:metadata) - send_result({ resourceTemplates: templated_resources_list }, id) + send_result(client_id, { resourceTemplates: templated_resources_list }, id) end # Handle resources/subscribe request - def handle_resources_subscribe(params, id) + def handle_resources_subscribe(params, headers, id) return unless @client_initialized uri = params['uri'] + client_id = headers['client_id'] unless uri - send_error(-32_602, 'Invalid params: missing resource URI', id) + send_error(-32_602, client_id, 'Invalid params: missing resource URI', id) return end resource = @resources.find { |r| r.match(uri) } - return send_error(-32_602, "Resource not found: #{uri}", id) unless resource + return send_error(-32_602, client_id, "Resource not found: #{uri}", id) unless resource # Add to subscriptions @resource_subscriptions[uri] ||= [] @resource_subscriptions[uri] << id - send_result({ subscribed: true }, id) + send_result(client_id, { subscribed: true }, id) end # Handle resources/unsubscribe request - def handle_resources_unsubscribe(params, id) + def handle_resources_unsubscribe(params, headers, id) return unless @client_initialized uri = params['uri'] + client_id = headers['client_id'] unless uri - send_error(-32_602, 'Invalid params: missing resource URI', id) + send_error(-32_602, client_id, 'Invalid params: missing resource URI', id) return end @@ -416,7 +425,7 @@ def handle_resources_unsubscribe(params, id) @resource_subscriptions.delete(uri) if @resource_subscriptions[uri].empty? end - send_result({ unsubscribed: true }, id) + send_result(client_id, { unsubscribed: true }, id) end # Notify clients about resource list changes @@ -433,7 +442,7 @@ def notify_resource_list_changed end # Send a JSON-RPC result response - def send_result(result, id, metadata: {}) + def send_result(client_id, result, id, metadata: {}) result[:_meta] = metadata if metadata.is_a?(Hash) && !metadata.empty? response = { @@ -443,11 +452,11 @@ def send_result(result, id, metadata: {}) } @logger.info("Sending result: #{response.inspect}") - send_response(response) + send_response(client_id, response) end # Send a JSON-RPC error response - def send_error(code, message, id = nil) + def send_error(code, client_id, message, id = nil) response = { jsonrpc: '2.0', error: { @@ -457,14 +466,14 @@ def send_error(code, message, id = nil) id: id } - send_response(response) + send_response(client_id, response) end # Send a JSON-RPC response - def send_response(response) + def send_response(client_id, response) if @transport @logger.debug("Sending response: #{response.inspect}") - @transport.send_message(response) + @transport.send_message_to(client_id, response) else @logger.warn("No transport available to send response: #{response.inspect}") @logger.warn("Transport: #{@transport.inspect}, transport_klass: #{@transport_klass.inspect}") diff --git a/lib/mcp/transports/base_transport.rb b/lib/mcp/transports/base_transport.rb index 33a462b..1857e31 100644 --- a/lib/mcp/transports/base_transport.rb +++ b/lib/mcp/transports/base_transport.rb @@ -30,6 +30,10 @@ def send_message(message) raise NotImplementedError, "#{self.class} must implement #send_message" end + def send_message_to(client_id, message) + raise NotImplementedError, "#{self.class} must implement #send_message_to" + end + # Process an incoming message # This is a helper method that can be used by subclasses def process_message(message, headers: {}) diff --git a/lib/mcp/transports/rack_transport.rb b/lib/mcp/transports/rack_transport.rb index f1d6a43..fcdf18d 100644 --- a/lib/mcp/transports/rack_transport.rb +++ b/lib/mcp/transports/rack_transport.rb @@ -74,31 +74,41 @@ def stop def send_message(message) json_message = message.is_a?(String) ? message : JSON.generate(message) @logger.debug("Broadcasting message to #{@sse_clients.size} SSE clients: #{json_message}") + clients_to_message = @sse_clients.keys - clients_to_remove = [] - @sse_clients_mutex.synchronize do - @sse_clients.each do |client_id, client| - stream = client[:stream] - mutex = client[:mutex] - next if stream.nil? || (stream.respond_to?(:closed?) && stream.closed?) || mutex.nil? - - begin - mutex.synchronize do - stream.write("data: #{json_message}\n\n") - stream.flush if stream.respond_to?(:flush) - end - rescue Errno::EPIPE, IOError => e - @logger.info("Client #{client_id} disconnected: #{e.message}") - clients_to_remove << client_id - rescue StandardError => e - @logger.error("Error sending message to client #{client_id}: #{e.message}") - clients_to_remove << client_id - end - end + clients_to_message.each do |client_id| + send_message_to(client_id, message) end + end - # Remove disconnected clients outside the loop to avoid modifying the hash during iteration - clients_to_remove.each { |client_id| unregister_sse_client(client_id) } + # Send a message to a specific SSE client + def send_message_to(client_id, message) + client = @sse_clients[client_id] + if client.nil? + @logger.info("Client #{client_id} not found, skipping message") + return + end + + json_message = message.is_a?(String) ? message : JSON.generate(message) + stream = client[:stream] + + if stream.nil? || (stream.respond_to?(:closed?) && stream.closed?) + unregister_sse_client(client_id) + else + client[:mutex].synchronize do + stream.write("data: #{json_message}\n\n") + stream.flush if stream.respond_to?(:flush) + end + end + nil + rescue Errno::EPIPE, IOError => e + @logger.info("Client #{client_id} disconnected: #{e.message}") + unregister_sse_client(client_id) + nil + rescue StandardError => e + @logger.error("Error sending message to client #{client_id}: #{e.message}") + unregister_sse_client(client_id) + nil end # Register a new SSE client @@ -111,6 +121,11 @@ def register_sse_client(client_id, stream, mutex = nil) # Unregister an SSE client def unregister_sse_client(client_id) + existing_client = @sse_clients[client_id] + return unless existing_client + + existing_client[:stream].close if existing_client[:stream].respond_to?(:close) + @sse_clients_mutex.synchronize do @logger.info("Unregistering SSE client: #{client_id}") @sse_clients.delete(client_id) @@ -318,6 +333,7 @@ def setup_cors_headers def extract_client_id(env) request = Rack::Request.new(env) + @logger.info("Extracting client ID from request: #{request.params}") # Check various places for client ID client_id = request.params['client_id'] client_id ||= env['HTTP_LAST_EVENT_ID'] @@ -328,12 +344,9 @@ def extract_client_id(env) browser_type = detect_browser_type(user_agent) @logger.info("Client connection from: #{user_agent} (#{browser_type})") - # Handle reconnection - if client_id && @sse_clients.key?(client_id) - handle_client_reconnection(client_id, browser_type) - else + unless client_id # Generate a new client ID if none was provided - client_id ||= SecureRandom.uuid + client_id = SecureRandom.uuid @logger.info("New client connection: #{client_id} (#{browser_type})") end @@ -360,21 +373,6 @@ def detect_browser_type(user_agent) end end - # Handle client reconnection - def handle_client_reconnection(client_id, browser_type) - @logger.info("Client #{client_id} is reconnecting (#{browser_type})") - old_client = @sse_clients[client_id] - begin - old_client[:stream].close if old_client[:stream].respond_to?(:close) && !old_client[:stream].closed? - rescue StandardError => e - @logger.error("Error closing old connection for client #{client_id}: #{e.message}") - end - unregister_sse_client(client_id) - - # Small delay to ensure the old connection is fully closed - sleep 0.1 - end - # Handle SSE with Rack hijacking (e.g., Puma) def handle_rack_hijack_sse(env) client_id = extract_client_id(env) @@ -392,6 +390,8 @@ def handle_rack_hijack_sse(env) end # Set up the SSE connection + # If SSE connection already exists for a client through a different IO, + # it will be closed and a new one will be established def setup_sse_connection(client_id, io, env) # Handle for reconnection, if the client_id is already registered we reuse the mutex # If not a reconnection, generate a new mutex used in registration @@ -399,38 +399,46 @@ def setup_sse_connection(client_id, io, env) mutex = client ? client[:mutex] : Mutex.new # Send headers @logger.debug("Sending HTTP headers for SSE connection #{client_id}") - mutex.synchronize do - io.write("HTTP/1.1 200 OK\r\n") - SSE_HEADERS.each { |k, v| io.write("#{k}: #{v}\r\n") } - io.write("\r\n") - io.flush - end + mutex.synchronize { write_sse_headers(io) } # Register client (will overwrite if already present) register_sse_client(client_id, io, mutex) - # Send an initial comment to keep the connection alive - mutex.synchronize { io.write(": SSE connection established\n\n") } + # Extract query parameters from the request and generate the endpoint + # the client will use to send messages to the server + endpoint = generate_endpoint_info(client_id, env['QUERY_STRING']) + @logger.debug("Sending endpoint information to client #{client_id}: #{endpoint}") + mutex.synchronize { write_sse_initialize(io, endpoint) } + rescue StandardError => e + @logger.error("Error setting up SSE connection for client #{client_id}: #{e.message}") + @logger.error(e.backtrace.join("\n")) if e.backtrace + raise + end - # Extract query parameters from the request - query_string = env['QUERY_STRING'] + def write_sse_headers(stream) + stream.write("HTTP/1.1 200 OK\r\n") - # Send endpoint information as the first message with query parameters + SSE_HEADERS.each { |k, v| stream.write("#{k}: #{v}\r\n") } + stream.write("\r\n") + stream.flush + end + + def generate_endpoint_info(client_id, query_string = '') endpoint = "#{@path_prefix}/#{@messages_route}" - endpoint += "?#{query_string}" if query_string - @logger.debug("Sending endpoint information to client #{client_id}: #{endpoint}") - mutex.synchronize { io.write("event: endpoint\ndata: #{endpoint}\n\n") } + params = [] + params << query_string if query_string && !query_string.empty? + params << "client_id=#{client_id}" + endpoint += "?#{params.join('&')}" + endpoint + end + def write_sse_initialize(stream, endpoint) + stream.write(": SSE connection established\n\n") + stream.write("event: endpoint\ndata: #{endpoint}\n\n") # Send a retry directive with a very short reconnect time # This helps browsers reconnect quickly if the connection is lost - mutex.synchronize do - io.write("retry: 100\n\n") - io.flush - end - rescue StandardError => e - @logger.error("Error setting up SSE connection for client #{client_id}: #{e.message}") - @logger.error(e.backtrace.join("\n")) if e.backtrace - raise + stream.write("retry: 100\n\n") + stream.flush end # Start a keep-alive thread for SSE connection @@ -544,6 +552,8 @@ def process_json_request_with_server(request, server) headers = request.env.select { |k, _v| k.start_with?('HTTP_') } .transform_keys { |k| k.sub('HTTP_', '').downcase.tr('_', '-') } + headers['client_id'] = extract_client_id(request.env) + # Let the specific server handle the JSON request directly response = server.handle_request(body, headers: headers) || [] diff --git a/lib/mcp/transports/stdio_transport.rb b/lib/mcp/transports/stdio_transport.rb index 2941169..cd02f00 100644 --- a/lib/mcp/transports/stdio_transport.rb +++ b/lib/mcp/transports/stdio_transport.rb @@ -43,6 +43,11 @@ def send_message(message) $stdout.flush end + # stdio transport does not support sending to specific clients + def send_message_to(_client_id, message) + send_message(message) + end + private # Send a JSON-RPC error response diff --git a/spec/mcp/server_spec.rb b/spec/mcp/server_spec.rb index 572f780..9f5ba00 100644 --- a/spec/mcp/server_spec.rb +++ b/spec/mcp/server_spec.rb @@ -34,6 +34,8 @@ def call(**_args) end describe '#handle_request' do + let(:client_id) { 'test-client-id' } + let(:headers) { { 'client_id' => client_id } } let(:test_tool_class) do Class.new(FastMcp::Tool) do def self.name @@ -90,8 +92,8 @@ def call(user:) it 'responds with an empty result' do request = { jsonrpc: '2.0', method: 'ping', id: 1 }.to_json - expect(server).to receive(:send_result).with({}, 1) - server.handle_request(request) + expect(server).to receive(:send_result).with(client_id, {}, 1) + server.handle_request(request, headers: headers) end end @@ -118,7 +120,7 @@ def call(user:) it 'responds with the server info' do request = { jsonrpc: '2.0', method: 'initialize', id: 1 }.to_json - expect(server).to receive(:send_result).with({ + expect(server).to receive(:send_result).with(client_id, { protocolVersion: FastMcp::Server::PROTOCOL_VERSION, capabilities: server.capabilities, serverInfo: { @@ -126,7 +128,7 @@ def call(user:) version: server.version } }, 1) - server.handle_request(request) + server.handle_request(request, headers: headers) end end @@ -134,7 +136,7 @@ def call(user:) it 'responds with a list of tools' do request = { jsonrpc: '2.0', method: 'tools/list', id: 1 }.to_json - expect(server).to receive(:send_result) do |result, id| + expect(server).to receive(:send_result) do |_client_id, result, id| expect(id).to eq(1) expect(result[:tools]).to be_an(Array) expect(result[:tools].length).to eq(2) @@ -154,7 +156,7 @@ def call(user:) expect(profile_tool[:inputSchema][:properties][:user][:properties]).to have_key(:last_name) end - server.handle_request(request) + server.handle_request(request, headers: headers) end end @@ -171,11 +173,12 @@ def call(user:) }.to_json expect(server).to receive(:send_result).with( + client_id, { content: [{ text: 'Hello, World!', type: 'text' }], isError: false }, 1, metadata: {} ) - server.handle_request(request) + server.handle_request(request, headers: headers) end it 'calls a tool with nested properties' do @@ -195,11 +198,12 @@ def call(user:) }.to_json expect(server).to receive(:send_result).with( + client_id, { content: [{ text: 'John Doe', type: 'text' }], isError: false }, 1, metadata: {} ) - server.handle_request(request) + server.handle_request(request, headers: headers) end it "returns an error if the tool doesn't exist" do @@ -213,8 +217,8 @@ def call(user:) id: 1 }.to_json - expect(server).to receive(:send_error).with(-32_602, 'Tool not found: non-existent-tool', 1) - server.handle_request(request) + expect(server).to receive(:send_error).with(-32_602, client_id, 'Tool not found: non-existent-tool', 1) + server.handle_request(request, headers: headers) end it 'returns an error if the tool name is missing' do @@ -227,8 +231,8 @@ def call(user:) id: 1 }.to_json - expect(server).to receive(:send_error).with(-32_602, 'Invalid params: missing tool name', 1) - server.handle_request(request) + expect(server).to receive(:send_error).with(-32_602, client_id, 'Invalid params: missing tool name', 1) + server.handle_request(request, headers: headers) end end @@ -236,22 +240,22 @@ def call(user:) it 'returns an error for an unknown method' do request = { jsonrpc: '2.0', method: 'unknown', id: 1 }.to_json - expect(server).to receive(:send_error).with(-32_601, 'Method not found: unknown', 1) - server.handle_request(request) + expect(server).to receive(:send_error).with(-32_601, client_id, 'Method not found: unknown', 1) + server.handle_request(request, headers: headers) end it 'returns an error for an invalid JSON-RPC request' do request = { id: 1 }.to_json - expect(server).to receive(:send_error).with(-32_600, 'Invalid Request', 1) - server.handle_request(request) + expect(server).to receive(:send_error).with(-32_600, client_id, 'Invalid Request', 1) + server.handle_request(request, headers: headers) end it 'returns an error for an invalid JSON request' do request = 'invalid json' - expect(server).to receive(:send_error).with(-32_600, 'Invalid Request', nil) - server.handle_request(request) + expect(server).to receive(:send_error).with(-32_600, client_id, 'Invalid Request', nil) + server.handle_request(request, headers: headers) end end end diff --git a/spec/mcp/transports/rack_transport_spec.rb b/spec/mcp/transports/rack_transport_spec.rb index 29b7ec5..67f7e1f 100644 --- a/spec/mcp/transports/rack_transport_spec.rb +++ b/spec/mcp/transports/rack_transport_spec.rb @@ -132,8 +132,10 @@ # Add a mock SSE client that raises an error client_stream = double('stream') expect(client_stream).to receive(:respond_to?).with(:closed?).and_return(true) + expect(client_stream).to receive(:respond_to?).with(:close).and_return(true) expect(client_stream).to receive(:closed?).and_return(false) expect(client_stream).to receive(:write).and_raise(StandardError.new('Test error')) + expect(client_stream).to receive(:close) transport.instance_variable_set(:@sse_clients, { 'test-client' => { stream: client_stream, mutex: Mutex.new } }) @@ -156,6 +158,7 @@ allow(client_stream).to receive(:closed?).and_return(false) allow(client_stream).to receive(:write) allow(client_stream).to receive(:flush) + allow(client_stream).to receive(:close) # Create a client with a mutex that will raise an error client_mutex = double('mutex')