Skip to content

Commit

Permalink
Fix how requests are constructed from server base URL (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt authored Jan 26, 2025
1 parent 9dcdd8c commit 9676728
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 4 deletions.
28 changes: 24 additions & 4 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log/slog"
"net/http"
"net/url"
"path"
"reflect"
"strings"

Expand Down Expand Up @@ -346,7 +347,7 @@ func (s *Server) handleToolsList(request *ToolsListRequest) (*ToolsListResponse,
}

func (s *Server) handleToolsCall(request *ToolCallRequest) (*ToolCallResponse, error) {
method, path, operation, pathItem, found := s.findOperation(s.model, request.Name)
method, p, operation, pathItem, found := s.findOperation(s.model, request.Name)
if !found {
return nil, jsonrpc.NewError(jsonrpc.ErrMethodNotFound, nil)
}
Expand All @@ -357,13 +358,32 @@ func (s *Server) handleToolsCall(request *ToolCallRequest) (*ToolCallResponse, e
return nil, jsonrpc.NewError(jsonrpc.ErrInternal, err)
}

u := url.URL{
// Ensure the path starts with a slash
if !strings.HasPrefix(p, "/") {
p = "/" + p
}

// Clean the path to handle multiple slashes
p = path.Clean(p)

// Create a new URL with the base URL's scheme and host
u := &url.URL{
Scheme: baseURL.Scheme,
Host: baseURL.Host,
Path: path,
}

if baseURL.Scheme == "" {
// If the base URL has a path, join it with the operation path
if baseURL.Path != "" {
// Clean the base path
basePath := path.Clean(baseURL.Path)
// Join paths and ensure leading slash
u.Path = "/" + strings.TrimPrefix(path.Join(basePath, p), "/")
} else {
u.Path = p
}

// Set default scheme if not present
if u.Scheme == "" {
u.Scheme = "http"
}

Expand Down
100 changes: 100 additions & 0 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package mcp
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"path"
"strings"
"testing"

"github.com/loopwork-ai/emcee/internal"
Expand Down Expand Up @@ -758,3 +761,100 @@ func TestWithAuth(t *testing.T) {
})
}
}

func TestPathJoining(t *testing.T) {
tests := []struct {
name string
baseURL string
path string
expected string
}{
{
name: "simple paths",
baseURL: "https://api.example.com",
path: "/pets",
expected: "https://api.example.com/pets",
},
{
name: "base URL with trailing slash",
baseURL: "https://api.example.com/",
path: "/pets",
expected: "https://api.example.com/pets",
},
{
name: "base URL with path",
baseURL: "https://api.example.com/v1",
path: "/pets",
expected: "https://api.example.com/v1/pets",
},
{
name: "base URL with path and trailing slash",
baseURL: "https://api.example.com/v1/",
path: "/pets",
expected: "https://api.example.com/v1/pets",
},
{
name: "path without leading slash",
baseURL: "https://api.example.com/v1",
path: "pets",
expected: "https://api.example.com/v1/pets",
},
{
name: "multiple path segments",
baseURL: "https://api.example.com/v1",
path: "/pets/dogs",
expected: "https://api.example.com/v1/pets/dogs",
},
{
name: "multiple slashes in path",
baseURL: "https://api.example.com/v1/",
path: "//pets///dogs",
expected: "https://api.example.com/v1/pets/dogs",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a mock HTTP server first
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Parse the request URL and compare with expected
actualPath := path.Clean(r.URL.Path)
expectedPath := path.Clean(tt.path)
if !strings.HasPrefix(expectedPath, "/") {
expectedPath = "/" + expectedPath
}
assert.Equal(t, expectedPath, actualPath)

w.WriteHeader(http.StatusOK)
w.Write([]byte("{}"))
}))
defer ts.Close()

// Create a test spec with the test server URL
spec := fmt.Sprintf(`{
"openapi": "3.0.0",
"servers": [{"url": "%s"}],
"paths": {
"%s": {
"get": {
"operationId": "testOperation"
}
}
}
}`, ts.URL, tt.path)

server, err := NewServer(
WithSpecData([]byte(spec)),
WithClient(ts.Client()),
)
require.NoError(t, err)

// Make a test request
request := jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "testOperation"}`), 1)
response := server.HandleRequest(request)

// Verify the request was successful
assert.Nil(t, response.Error)
})
}
}

0 comments on commit 9676728

Please sign in to comment.