diff --git a/pkg/mcp/parser.go b/pkg/mcp/parser.go index f7018f1f8..3ac11628d 100644 --- a/pkg/mcp/parser.go +++ b/pkg/mcp/parser.go @@ -7,6 +7,7 @@ import ( "encoding/json" "io" "net/http" + "strconv" "strings" "golang.org/x/exp/jsonrpc2" @@ -135,18 +136,25 @@ func parseMCPRequest(bodyBytes []byte) *ParsedMCPRequest { return nil } - // Handle only request messages + // Handle only request messages (both calls with ID and notifications without ID) req, ok := msg.(*jsonrpc2.Request) if !ok { + // Response or error messages are not parsed here return nil } // Extract resource ID and arguments based on the method resourceID, arguments := extractResourceAndArguments(req.Method, req.Params) + // Determine the ID - will be nil for notifications + var id interface{} + if req.ID.IsValid() { + id = req.ID.Raw() + } + return &ParsedMCPRequest{ Method: req.Method, - ID: req.ID.Raw(), + ID: id, Params: req.Params, ResourceID: resourceID, Arguments: arguments, @@ -162,24 +170,36 @@ type methodHandler func(map[string]interface{}) (string, map[string]interface{}) // methodHandlers maps MCP methods to their respective handlers var methodHandlers = map[string]methodHandler{ - "initialize": handleInitializeMethod, - "tools/call": handleNamedResourceMethod, - "prompts/get": handleNamedResourceMethod, - "resources/read": handleResourceReadMethod, - "resources/list": handleListMethod, - "tools/list": handleListMethod, - "prompts/list": handleListMethod, - "progress/update": handleProgressMethod, - "notifications/message": handleNotificationMethod, - "logging/setLevel": handleLoggingMethod, - "completion/complete": handleCompletionMethod, + "initialize": handleInitializeMethod, + "tools/call": handleNamedResourceMethod, + "prompts/get": handleNamedResourceMethod, + "resources/read": handleResourceReadMethod, + "resources/list": handleListMethod, + "tools/list": handleListMethod, + "prompts/list": handleListMethod, + "progress/update": handleProgressMethod, + "notifications/message": handleNotificationMethod, + "logging/setLevel": handleLoggingMethod, + "completion/complete": handleCompletionMethod, + "elicitation/create": handleElicitationMethod, + "sampling/createMessage": handleSamplingMethod, + "resources/subscribe": handleResourceSubscribeMethod, + "resources/unsubscribe": handleResourceUnsubscribeMethod, + "resources/templates/list": handleListMethod, + "roots/list": handleListMethod, + "notifications/progress": handleProgressNotificationMethod, + "notifications/cancelled": handleCancelledNotificationMethod, } // staticResourceIDs maps methods to their static resource IDs var staticResourceIDs = map[string]string{ - "ping": "ping", - "notifications/roots/list_changed": "roots", - "notifications/initialized": "initialized", + "ping": "ping", + "notifications/roots/list_changed": "roots", + "notifications/initialized": "initialized", + "notifications/prompts/list_changed": "prompts", + "notifications/resources/list_changed": "resources", + "notifications/resources/updated": "resources", + "notifications/tools/list_changed": "tools", } func extractResourceAndArguments(method string, params json.RawMessage) (string, map[string]interface{}) { @@ -277,14 +297,114 @@ func handleLoggingMethod(paramsMap map[string]interface{}) (string, map[string]i return "", nil } -// handleCompletionMethod extracts resource ID for completion requests +// handleCompletionMethod extracts resource ID for completion requests. +// For PromptReference: extracts the prompt name +// For ResourceTemplateReference: extracts the template URI +// For legacy string ref: returns the string value +// Always returns paramsMap as arguments since completion requests need the full context +// including the argument being completed and any context from previous completions. func handleCompletionMethod(paramsMap map[string]interface{}) (string, map[string]interface{}) { + // Check if ref is a map (PromptReference or ResourceTemplateReference) + if ref, ok := paramsMap["ref"].(map[string]interface{}); ok { + // Try to extract name for PromptReference + if name, ok := ref["name"].(string); ok { + return name, paramsMap + } + // Try to extract uri for ResourceTemplateReference + if uri, ok := ref["uri"].(string); ok { + return uri, paramsMap + } + } + // Fallback to string ref (legacy support) if ref, ok := paramsMap["ref"].(string); ok { - return ref, nil + return ref, paramsMap + } + return "", paramsMap +} + +// handleElicitationMethod extracts resource ID for elicitation requests +func handleElicitationMethod(paramsMap map[string]interface{}) (string, map[string]interface{}) { + // The message field could be used as a resource identifier + if message, ok := paramsMap["message"].(string); ok { + return message, paramsMap + } + return "", paramsMap +} + +// handleSamplingMethod extracts resource ID for sampling/createMessage requests. +// Returns the model name from modelPreferences if available, otherwise returns a +// truncated version of the systemPrompt. The 50-character truncation provides a +// reasonable balance between uniqueness and readability for authorization and audit logs. +func handleSamplingMethod(paramsMap map[string]interface{}) (string, map[string]interface{}) { + // Use model preferences or system prompt as identifier if available + if modelPrefs, ok := paramsMap["modelPreferences"].(map[string]interface{}); ok && modelPrefs != nil { + // Try direct name field first (simplified structure) + if name, ok := modelPrefs["name"].(string); ok && name != "" { + return name, paramsMap + } + // Try to get model name from hints array (full spec structure) + if hints, ok := modelPrefs["hints"].([]interface{}); ok && len(hints) > 0 { + if hint, ok := hints[0].(map[string]interface{}); ok { + if name, ok := hint["name"].(string); ok && name != "" { + return name, paramsMap + } + } + } + } + if systemPrompt, ok := paramsMap["systemPrompt"].(string); ok && systemPrompt != "" { + // Use first 50 chars of system prompt as identifier + // This provides a reasonable balance between uniqueness and readability + if len(systemPrompt) > 50 { + return systemPrompt[:50], paramsMap + } + return systemPrompt, paramsMap + } + return "", paramsMap +} + +// handleResourceSubscribeMethod extracts resource ID for resource subscribe operations +func handleResourceSubscribeMethod(paramsMap map[string]interface{}) (string, map[string]interface{}) { + if uri, ok := paramsMap["uri"].(string); ok { + return uri, nil } return "", nil } +// handleResourceUnsubscribeMethod extracts resource ID for resource unsubscribe operations +func handleResourceUnsubscribeMethod(paramsMap map[string]interface{}) (string, map[string]interface{}) { + if uri, ok := paramsMap["uri"].(string); ok { + return uri, nil + } + return "", nil +} + +// handleProgressNotificationMethod extracts resource ID for progress notifications. +// Extracts the progressToken which can be either a string or numeric value. +func handleProgressNotificationMethod(paramsMap map[string]interface{}) (string, map[string]interface{}) { + if token, ok := paramsMap["progressToken"].(string); ok { + return token, paramsMap + } + // Also handle numeric progress tokens + if token, ok := paramsMap["progressToken"].(float64); ok { + return strconv.FormatFloat(token, 'f', 0, 64), paramsMap + } + return "", paramsMap +} + +// handleCancelledNotificationMethod extracts resource ID for cancelled notifications. +// Extracts the requestId which can be either a string or numeric value. +func handleCancelledNotificationMethod(paramsMap map[string]interface{}) (string, map[string]interface{}) { + // Extract request ID as the resource identifier + if requestId, ok := paramsMap["requestId"].(string); ok { + return requestId, paramsMap + } + // Handle numeric request IDs + if requestId, ok := paramsMap["requestId"].(float64); ok { + return strconv.FormatFloat(requestId, 'f', 0, 64), paramsMap + } + return "", paramsMap +} + // GetMCPMethod is a convenience function to get the MCP method from the context. func GetMCPMethod(ctx context.Context) string { if parsed := GetParsedMCPRequest(ctx); parsed != nil { diff --git a/pkg/mcp/parser_test.go b/pkg/mcp/parser_test.go index d906ea1e4..c21d26923 100644 --- a/pkg/mcp/parser_test.go +++ b/pkg/mcp/parser_test.go @@ -287,6 +287,377 @@ func TestExtractResourceAndArguments(t *testing.T) { expectedResourceID: "", expectedArguments: nil, }, + { + name: "elicitation/create with message", + method: "elicitation/create", + params: `{"message":"Please provide your API key","requestedSchema":{"type":"object","properties":{"apiKey":{"type":"string"}}}}`, + expectedResourceID: "Please provide your API key", + expectedArguments: map[string]interface{}{ + "message": "Please provide your API key", + "requestedSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "apiKey": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + { + name: "sampling/createMessage with model preferences", + method: "sampling/createMessage", + params: `{"modelPreferences":{"name":"gpt-4"},"messages":[{"role":"user","content":{"type":"text","text":"Hello"}}],"maxTokens":100}`, + expectedResourceID: "gpt-4", + expectedArguments: map[string]interface{}{ + "modelPreferences": map[string]interface{}{ + "name": "gpt-4", + }, + "messages": []interface{}{ + map[string]interface{}{ + "role": "user", + "content": map[string]interface{}{ + "type": "text", + "text": "Hello", + }, + }, + }, + "maxTokens": float64(100), + }, + }, + { + name: "sampling/createMessage with system prompt", + method: "sampling/createMessage", + params: `{"systemPrompt":"You are a helpful assistant","messages":[],"maxTokens":100}`, + expectedResourceID: "You are a helpful assistant", + expectedArguments: map[string]interface{}{ + "systemPrompt": "You are a helpful assistant", + "messages": []interface{}{}, + "maxTokens": float64(100), + }, + }, + { + name: "resources/subscribe with URI", + method: "resources/subscribe", + params: `{"uri":"file:///watched.txt"}`, + expectedResourceID: "file:///watched.txt", + expectedArguments: nil, + }, + { + name: "resources/unsubscribe with URI", + method: "resources/unsubscribe", + params: `{"uri":"file:///unwatched.txt"}`, + expectedResourceID: "file:///unwatched.txt", + expectedArguments: nil, + }, + { + name: "resources/templates/list with cursor", + method: "resources/templates/list", + params: `{"cursor":"page-2"}`, + expectedResourceID: "page-2", + expectedArguments: nil, + }, + { + name: "roots/list empty params", + method: "roots/list", + params: `{}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "notifications/progress with string token", + method: "notifications/progress", + params: `{"progressToken":"task-456","progress":75,"total":100}`, + expectedResourceID: "task-456", + expectedArguments: map[string]interface{}{ + "progressToken": "task-456", + "progress": float64(75), + "total": float64(100), + }, + }, + { + name: "notifications/progress with numeric token", + method: "notifications/progress", + params: `{"progressToken":123,"progress":50}`, + expectedResourceID: "123", + expectedArguments: map[string]interface{}{ + "progressToken": float64(123), + "progress": float64(50), + }, + }, + { + name: "notifications/cancelled with string requestId", + method: "notifications/cancelled", + params: `{"requestId":"req-789","reason":"User cancelled"}`, + expectedResourceID: "req-789", + expectedArguments: map[string]interface{}{ + "requestId": "req-789", + "reason": "User cancelled", + }, + }, + { + name: "notifications/cancelled with numeric requestId", + method: "notifications/cancelled", + params: `{"requestId":456}`, + expectedResourceID: "456", + expectedArguments: map[string]interface{}{ + "requestId": float64(456), + }, + }, + { + name: "completion/complete with PromptReference", + method: "completion/complete", + params: `{"ref":{"type":"ref/prompt","name":"greeting"},"argument":{"name":"user","value":"Alice"}}`, + expectedResourceID: "greeting", + expectedArguments: map[string]interface{}{ + "ref": map[string]interface{}{ + "type": "ref/prompt", + "name": "greeting", + }, + "argument": map[string]interface{}{ + "name": "user", + "value": "Alice", + }, + }, + }, + { + name: "completion/complete with ResourceTemplateReference", + method: "completion/complete", + params: `{"ref":{"type":"ref/resource","uri":"template://example"},"argument":{"name":"param","value":"test"}}`, + expectedResourceID: "template://example", + expectedArguments: map[string]interface{}{ + "ref": map[string]interface{}{ + "type": "ref/resource", + "uri": "template://example", + }, + "argument": map[string]interface{}{ + "name": "param", + "value": "test", + }, + }, + }, + { + name: "notifications/prompts/list_changed", + method: "notifications/prompts/list_changed", + params: `{}`, + expectedResourceID: "prompts", + expectedArguments: nil, + }, + { + name: "notifications/resources/list_changed", + method: "notifications/resources/list_changed", + params: `{}`, + expectedResourceID: "resources", + expectedArguments: nil, + }, + { + name: "notifications/resources/updated", + method: "notifications/resources/updated", + params: `{"uri":"file:///updated.txt"}`, + expectedResourceID: "resources", + expectedArguments: nil, + }, + { + name: "notifications/tools/list_changed", + method: "notifications/tools/list_changed", + params: `{}`, + expectedResourceID: "tools", + expectedArguments: nil, + }, + // Edge cases and additional coverage + { + name: "empty params for method with handler", + method: "tools/call", + params: `{}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "null params", + method: "tools/call", + params: `null`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "resources/read with empty uri", + method: "resources/read", + params: `{"uri":""}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "resources/read with missing uri", + method: "resources/read", + params: `{"other":"value"}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "progress/update with missing token", + method: "progress/update", + params: `{"progress":50}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "logging/setLevel with missing level", + method: "logging/setLevel", + params: `{"other":"value"}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "notifications/message with method field", + method: "notifications/message", + params: `{"method":"test-method","data":"test"}`, + expectedResourceID: "test-method", + expectedArguments: nil, + }, + { + name: "notifications/message without method field", + method: "notifications/message", + params: `{"data":"test"}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "elicitation/create without message", + method: "elicitation/create", + params: `{"requestedSchema":{"type":"object"}}`, + expectedResourceID: "", + expectedArguments: map[string]interface{}{ + "requestedSchema": map[string]interface{}{ + "type": "object", + }, + }, + }, + { + name: "sampling/createMessage without preferences or prompt", + method: "sampling/createMessage", + params: `{"messages":[],"maxTokens":100}`, + expectedResourceID: "", + expectedArguments: map[string]interface{}{ + "messages": []interface{}{}, + "maxTokens": float64(100), + }, + }, + { + name: "sampling/createMessage with long system prompt", + method: "sampling/createMessage", + params: `{"systemPrompt":"This is a very long system prompt that exceeds fifty characters and should be truncated","messages":[],"maxTokens":100}`, + expectedResourceID: "This is a very long system prompt that exceeds fif", + expectedArguments: map[string]interface{}{ + "systemPrompt": "This is a very long system prompt that exceeds fifty characters and should be truncated", + "messages": []interface{}{}, + "maxTokens": float64(100), + }, + }, + { + name: "resources/subscribe with missing uri", + method: "resources/subscribe", + params: `{"other":"value"}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "resources/unsubscribe with missing uri", + method: "resources/unsubscribe", + params: `{"other":"value"}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "completion/complete with legacy string ref", + method: "completion/complete", + params: `{"ref":"legacy-ref","argument":{"name":"test","value":"val"}}`, + expectedResourceID: "legacy-ref", + expectedArguments: map[string]interface{}{ + "ref": "legacy-ref", + "argument": map[string]interface{}{ + "name": "test", + "value": "val", + }, + }, + }, + { + name: "completion/complete with invalid ref type", + method: "completion/complete", + params: `{"ref":123,"argument":{"name":"test","value":"val"}}`, + expectedResourceID: "", + expectedArguments: map[string]interface{}{ + "ref": float64(123), + "argument": map[string]interface{}{"name": "test", "value": "val"}, + }, + }, + { + name: "completion/complete with ref missing name and uri", + method: "completion/complete", + params: `{"ref":{"type":"ref/prompt"},"argument":{"name":"test","value":"val"}}`, + expectedResourceID: "", + expectedArguments: map[string]interface{}{ + "ref": map[string]interface{}{ + "type": "ref/prompt", + }, + "argument": map[string]interface{}{ + "name": "test", + "value": "val", + }, + }, + }, + { + name: "notifications/progress with missing progressToken", + method: "notifications/progress", + params: `{"progress":50}`, + expectedResourceID: "", + expectedArguments: map[string]interface{}{ + "progress": float64(50), + }, + }, + { + name: "notifications/cancelled with missing requestId", + method: "notifications/cancelled", + params: `{"reason":"User cancelled"}`, + expectedResourceID: "", + expectedArguments: map[string]interface{}{ + "reason": "User cancelled", + }, + }, + { + name: "tools/list with empty cursor", + method: "tools/list", + params: `{"cursor":""}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "prompts/list with empty cursor", + method: "prompts/list", + params: `{"cursor":""}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "resources/list with empty cursor", + method: "resources/list", + params: `{"cursor":""}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "resources/templates/list with empty cursor", + method: "resources/templates/list", + params: `{"cursor":""}`, + expectedResourceID: "", + expectedArguments: nil, + }, + { + name: "roots/list with cursor", + method: "roots/list", + params: `{"cursor":"page-2"}`, + expectedResourceID: "page-2", + expectedArguments: nil, + }, } for _, tt := range tests { @@ -497,3 +868,353 @@ func TestMiddlewarePreservesRequestBody(t *testing.T) { // Verify the request body was preserved for the next handler assert.Equal(t, originalBody, capturedBody) } + +func TestParsingMiddlewareErrorHandling(t *testing.T) { + t.Parallel() + tests := []struct { + name string + method string + path string + contentType string + body io.Reader + expectParsed bool + }{ + { + name: "body read error simulation", + method: "POST", + path: "/messages", + contentType: "application/json", + body: &errorReader{}, + expectParsed: false, + }, + { + name: "empty body", + method: "POST", + path: "/messages", + contentType: "application/json", + body: bytes.NewBufferString(""), + expectParsed: false, + }, + { + name: "malformed JSON", + method: "POST", + path: "/messages", + contentType: "application/json", + body: bytes.NewBufferString(`{"invalid json`), + expectParsed: false, + }, + { + name: "JSON array instead of object", + method: "POST", + path: "/messages", + contentType: "application/json", + body: bytes.NewBufferString(`[{"jsonrpc":"2.0"}]`), + expectParsed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Create a test handler that captures the context + var capturedCtx context.Context + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedCtx = r.Context() + w.WriteHeader(http.StatusOK) + }) + + // Wrap with parsing middleware + middleware := ParsingMiddleware(testHandler) + + // Create test request + req := httptest.NewRequest(tt.method, tt.path, tt.body) + req.Header.Set("Content-Type", tt.contentType) + w := httptest.NewRecorder() + + // Execute the middleware + middleware.ServeHTTP(w, req) + + // Check if parsing occurred as expected + parsed := GetParsedMCPRequest(capturedCtx) + if tt.expectParsed { + assert.NotNil(t, parsed) + } else { + assert.Nil(t, parsed) + } + }) + } +} + +// errorReader simulates an io.Reader that always returns an error +type errorReader struct{} + +func (*errorReader) Read(_ []byte) (n int, err error) { + return 0, io.ErrUnexpectedEOF +} + +func TestExtractResourceAndArgumentsNilParams(t *testing.T) { + t.Parallel() + tests := []struct { + name string + method string + expectedResourceID string + }{ + { + name: "method with static resource ID", + method: "ping", + expectedResourceID: "ping", + }, + { + name: "method without handler or static ID", + method: "unknown/method", + expectedResourceID: "", + }, + { + name: "notifications/initialized", + method: "notifications/initialized", + expectedResourceID: "initialized", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + resourceID, arguments := extractResourceAndArguments(tt.method, nil) + assert.Equal(t, tt.expectedResourceID, resourceID) + assert.Nil(t, arguments) + }) + } +} + +func TestParsingMiddlewareWithBatchRequests(t *testing.T) { + t.Parallel() + // Test batch JSON-RPC requests (currently not supported but should not crash) + batchBody := `[ + {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"tool1"}}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"tool2"}} + ]` + + var capturedCtx context.Context + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedCtx = r.Context() + w.WriteHeader(http.StatusOK) + }) + + middleware := ParsingMiddleware(testHandler) + req := httptest.NewRequest("POST", "/messages", bytes.NewBufferString(batchBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + middleware.ServeHTTP(w, req) + + // Batch requests should not be parsed (not supported yet) + parsed := GetParsedMCPRequest(capturedCtx) + assert.Nil(t, parsed) +} + +func TestConvenienceFunctionsWithNilContext(t *testing.T) { + t.Parallel() + // Test convenience functions with nil parsed request + ctx := context.Background() + + assert.Equal(t, "", GetMCPMethod(ctx)) + assert.Equal(t, "", GetMCPResourceID(ctx)) + assert.Nil(t, GetMCPArguments(ctx)) +} + +func TestHandlerFunctionsEdgeCases(t *testing.T) { + t.Parallel() + tests := []struct { + name string + handler func(map[string]interface{}) (string, map[string]interface{}) + params map[string]interface{} + expectedID string + checkArgs bool + }{ + { + name: "handleInitializeMethod with missing clientInfo", + handler: handleInitializeMethod, + params: map[string]interface{}{ + "protocolVersion": "2024-11-05", + }, + expectedID: "", + checkArgs: true, + }, + { + name: "handleInitializeMethod with non-map clientInfo", + handler: handleInitializeMethod, + params: map[string]interface{}{ + "clientInfo": "not-a-map", + }, + expectedID: "", + checkArgs: true, + }, + { + name: "handleInitializeMethod with clientInfo missing name", + handler: handleInitializeMethod, + params: map[string]interface{}{ + "clientInfo": map[string]interface{}{ + "version": "1.0.0", + }, + }, + expectedID: "", + checkArgs: true, + }, + { + name: "handleNamedResourceMethod with non-string name", + handler: handleNamedResourceMethod, + params: map[string]interface{}{ + "name": 123, + }, + expectedID: "", + checkArgs: false, + }, + { + name: "handleNamedResourceMethod with non-map arguments", + handler: handleNamedResourceMethod, + params: map[string]interface{}{ + "name": "test", + "arguments": "not-a-map", + }, + expectedID: "test", + checkArgs: false, + }, + { + name: "handleSamplingMethod with non-map modelPreferences", + handler: handleSamplingMethod, + params: map[string]interface{}{ + "modelPreferences": "not-a-map", + }, + expectedID: "", + checkArgs: true, + }, + { + name: "handleSamplingMethod with modelPreferences missing name", + handler: handleSamplingMethod, + params: map[string]interface{}{ + "modelPreferences": map[string]interface{}{ + "speedPriority": 1, + }, + }, + expectedID: "", + checkArgs: true, + }, + { + name: "handleProgressNotificationMethod with invalid numeric token", + handler: handleProgressNotificationMethod, + params: map[string]interface{}{ + "progressToken": "not-a-number", + }, + expectedID: "not-a-number", + checkArgs: true, + }, + { + name: "handleCancelledNotificationMethod with invalid numeric requestId", + handler: handleCancelledNotificationMethod, + params: map[string]interface{}{ + "requestId": "not-a-number", + }, + expectedID: "not-a-number", + checkArgs: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + resourceID, args := tt.handler(tt.params) + assert.Equal(t, tt.expectedID, resourceID) + if tt.checkArgs { + assert.Equal(t, tt.params, args) + } + }) + } +} + +func TestParsingMiddlewareIntegration(t *testing.T) { + t.Parallel() + // Test that the middleware correctly integrates with a full request/response cycle + tests := []struct { + name string + body string + expectedMethod string + expectedResourceID string + expectedArguments map[string]interface{} + }{ + { + name: "complex nested parameters", + body: `{ + "jsonrpc": "2.0", + "id": "complex-1", + "method": "tools/call", + "params": { + "name": "complex_tool", + "arguments": { + "nested": { + "deep": { + "value": "test" + } + }, + "array": [1, 2, 3], + "boolean": true, + "null": null + } + } + }`, + expectedMethod: "tools/call", + expectedResourceID: "complex_tool", + expectedArguments: map[string]interface{}{ + "nested": map[string]interface{}{ + "deep": map[string]interface{}{ + "value": "test", + }, + }, + "array": []interface{}{float64(1), float64(2), float64(3)}, + "boolean": true, + "null": nil, + }, + }, + { + name: "JSON-RPC notification (no id)", + body: `{ + "jsonrpc": "2.0", + "method": "notifications/message", + "params": { + "method": "log", + "level": "info", + "message": "test" + } + }`, + expectedMethod: "notifications/message", + expectedResourceID: "log", + expectedArguments: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var parsed *ParsedMCPRequest + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + parsed = GetParsedMCPRequest(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + middleware := ParsingMiddleware(testHandler) + req := httptest.NewRequest("POST", "/messages", bytes.NewBufferString(tt.body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + middleware.ServeHTTP(w, req) + + if tt.expectedMethod != "" { + require.NotNil(t, parsed) + assert.Equal(t, tt.expectedMethod, parsed.Method) + assert.Equal(t, tt.expectedResourceID, parsed.ResourceID) + assert.Equal(t, tt.expectedArguments, parsed.Arguments) + } else { + assert.Nil(t, parsed) + } + }) + } +}