diff --git a/cmd/emcee/main.go b/cmd/emcee/main.go index 1440f93..f208f10 100644 --- a/cmd/emcee/main.go +++ b/cmd/emcee/main.go @@ -16,6 +16,7 @@ import ( "github.com/spf13/cobra" "golang.org/x/sync/errgroup" + "github.com/loopwork-ai/emcee/internal" "github.com/loopwork-ai/emcee/mcp" ) @@ -73,14 +74,31 @@ The spec-path-or-url argument can be: return retryablehttp.DefaultBackoff(min, max, attemptNum, resp) } } - client := retryClient.StandardClient() - opts = append(opts, mcp.WithClient(client)) - // Set Authentication header if provided + // Set default headers if auth is provided if auth != "" { - opts = append(opts, mcp.WithAuth(auth)) + parts := strings.SplitN(auth, " ", 2) + if len(parts) == 1 { + // Only token provided, add Bearer prefix + logger.Warn("no auth scheme provided, automatically adding 'Bearer' prefix") + auth = "Bearer " + parts[0] + } else if len(parts) == 2 { + // Scheme and token provided, use as-is + auth = fmt.Sprintf("%s %s", parts[0], parts[1]) + } + + headers := http.Header{} + headers.Add("Authorization", auth) + + retryClient.HTTPClient.Transport = &internal.HeaderTransport{ + Base: retryClient.HTTPClient.Transport, + Headers: headers, + } } + client := retryClient.StandardClient() + opts = append(opts, mcp.WithClient(client)) + // Read OpenAPI specification data var rpcInput io.Reader = os.Stdin var specData []byte diff --git a/internal/http.go b/internal/http.go new file mode 100644 index 0000000..41b9257 --- /dev/null +++ b/internal/http.go @@ -0,0 +1,22 @@ +package internal + +import "net/http" + +// HeaderTransport is a custom RoundTripper that adds default headers to requests +type HeaderTransport struct { + Base http.RoundTripper + Headers http.Header +} + +func (t *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) { + for key, values := range t.Headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + base := t.Base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(req) +} diff --git a/mcp/server.go b/mcp/server.go index 0013859..5973739 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -21,14 +21,6 @@ import ( // ServerOption configures a Server type ServerOption func(*Server) error -// WithAuth sets the HTTP Authorization header -func WithAuth(auth string) ServerOption { - return func(s *Server) error { - s.authHeader = auth - return nil - } -} - // WithClient sets the HTTP client func WithClient(client *http.Client) ServerOption { return func(s *Server) error { @@ -88,19 +80,20 @@ func WithLogger(logger *slog.Logger) ServerOption { // Server represents an MCP server that processes JSON-RPC requests type Server struct { - doc libopenapi.Document - model *v3.Document - baseURL string - client *http.Client - info ServerInfo - authHeader string - logger *slog.Logger + doc libopenapi.Document + model *v3.Document + baseURL string + client *http.Client + info ServerInfo + logger *slog.Logger } // NewServer creates a new MCP server instance func NewServer(opts ...ServerOption) (*Server, error) { s := &Server{ - client: http.DefaultClient, + client: &http.Client{ + Transport: http.DefaultTransport, + }, } // Apply options @@ -441,9 +434,6 @@ func (s *Server) handleToolsCall(request *ToolCallRequest) (*ToolCallResponse, e if reqBody != nil { req.Header.Set("Content-Type", "application/json") } - if s.authHeader != "" { - req.Header.Set("Authorization", s.authHeader) - } // Send request resp, err := s.client.Do(req) @@ -468,6 +458,7 @@ func (s *Server) handleToolsCall(request *ToolCallRequest) (*ToolCallResponse, e contentType := resp.Header.Get("Content-Type") var content Content + // Create content based on response content type if strings.HasPrefix(contentType, "image/") { encoded := base64.StdEncoding.EncodeToString(body) content = NewImageContent(encoded, contentType, []Role{RoleAssistant}, nil) diff --git a/mcp/server_test.go b/mcp/server_test.go index bc9996b..26d8873 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "testing" + "github.com/loopwork-ai/emcee/internal" "github.com/loopwork-ai/emcee/jsonrpc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -131,8 +132,17 @@ func setupTestServer(t *testing.T) (*Server, *httptest.Server) { // Create a small test image imgData := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} // PNG header + // Track if auth header was checked + authHeaderChecked := false + var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify auth header if present + if authHeader := r.Header.Get("Authorization"); authHeader != "" { + assert.Equal(t, "Bearer test-token", authHeader, "Authorization header should match") + authHeaderChecked = true + } + switch r.URL.Path { case "/openapi.json": w.Header().Set("Content-Type", "application/json") @@ -147,6 +157,11 @@ func setupTestServer(t *testing.T) (*Server, *httptest.Server) { assert.Equal(t, "5", limit) assert.Equal(t, "dog", petType) + // For auth test case, verify the auth header was checked + if r.Header.Get("Authorization") != "" { + assert.True(t, authHeaderChecked, "Auth header should have been checked") + } + pets := []map[string]interface{}{ {"id": 1, "name": "Fluffy", "type": "dog"}, {"id": 2, "name": "Rover", "type": "dog"}, @@ -197,6 +212,12 @@ func setupTestServer(t *testing.T) (*Server, *httptest.Server) { } })) + client := ts.Client() + client.Transport = &internal.HeaderTransport{ + Base: client.Transport, + Headers: http.Header{"Authorization": []string{"Bearer test-token"}}, + } + // Create a server instance with the test server URL and spec server, err := NewServer( WithClient(ts.Client()), @@ -339,16 +360,21 @@ func TestHandleToolsCall(t *testing.T) { server, ts := setupTestServer(t) defer ts.Close() + // Test with auth header + serverWithAuth, _ := setupTestServer(t) + // Create a small test image imgData := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} // PNG header tests := []struct { name string + server *Server request jsonrpc.Request validate func(*testing.T, jsonrpc.Response) }{ { name: "GET request with query parameters", + server: server, request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "listPets", "arguments": {"limit": 5, "type": "dog"}}`), 1), validate: func(t *testing.T, response jsonrpc.Response) { assert.Equal(t, "2.0", response.Version) @@ -391,6 +417,7 @@ func TestHandleToolsCall(t *testing.T) { }, { name: "POST request with body parameters", + server: server, request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "createPet", "arguments": {"name": "Whiskers", "age": 5}}`), 2), validate: func(t *testing.T, response jsonrpc.Response) { assert.Equal(t, "2.0", response.Version) @@ -428,6 +455,7 @@ func TestHandleToolsCall(t *testing.T) { }, { name: "GET image request", + server: server, request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "getPetImage"}`), 3), validate: func(t *testing.T, response jsonrpc.Response) { assert.Equal(t, "2.0", response.Version) @@ -464,6 +492,7 @@ func TestHandleToolsCall(t *testing.T) { }, { name: "Request with invalid operationId", + server: server, request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "nonexistentOperation"}`), 4), validate: func(t *testing.T, response jsonrpc.Response) { assert.Equal(t, "2.0", response.Version) @@ -474,6 +503,7 @@ func TestHandleToolsCall(t *testing.T) { }, { name: "GET request with URL escaped parameters", + server: server, request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "getPet", "arguments": {"petId": "special pet"}}`), 5), validate: func(t *testing.T, response jsonrpc.Response) { assert.Equal(t, "2.0", response.Version) @@ -504,11 +534,58 @@ func TestHandleToolsCall(t *testing.T) { assert.Equal(t, "Special Pet", pet["name"]) }, }, + { + name: "Request with auth header", + server: serverWithAuth, + request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "listPets", "arguments": {"limit": 5, "type": "dog"}}`), 6), + validate: func(t *testing.T, response jsonrpc.Response) { + assert.Equal(t, "2.0", response.Version) + assert.Equal(t, 6, response.ID.Value()) + assert.Nil(t, response.Error) + + var result ToolCallResponse + resultBytes, err := json.Marshal(response.Result) + require.NoError(t, err) + err = json.Unmarshal(resultBytes, &result) + require.NoError(t, err) + + assert.Len(t, result.Content, 1) + assert.False(t, result.IsError) + + // Verify the response content + content := result.Content[0] + assert.Equal(t, "text", content.Type) + assert.NotNil(t, content.Annotations) + assert.Contains(t, content.Annotations.Audience, RoleAssistant) + + var textContent Content + contentBytes, err := json.Marshal(content) + assert.NoError(t, err) + err = json.Unmarshal(contentBytes, &textContent) + assert.NoError(t, err) + + var pets []interface{} + err = json.Unmarshal([]byte(textContent.Text), &pets) + assert.NoError(t, err) + assert.Len(t, pets, 2) + + // Verify the returned pets + for _, pet := range pets { + petMap := pet.(map[string]interface{}) + assert.Equal(t, "dog", petMap["type"]) + } + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response := server.HandleRequest(tt.request) + var response jsonrpc.Response + if tt.server != nil { + response = tt.server.HandleRequest(tt.request) + } else { + response = server.HandleRequest(tt.request) + } tt.validate(t, response) }) }