Skip to content

Commit

Permalink
Automatically add Bearer auth scheme if not provided (#14)
Browse files Browse the repository at this point in the history
* Add test coverage for WithAuth option

* Replace WithAuth option with custom HTTP transport that automatically injects headers

* Automatically add Bearer auth scheme if not provided

Extract HTTP transport into internal package

* Use string template instead of concatenation
  • Loading branch information
mattt authored Jan 22, 2025
1 parent 47f55d7 commit c4a70e3
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 24 deletions.
26 changes: 22 additions & 4 deletions cmd/emcee/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"

"github.com/loopwork-ai/emcee/internal"
"github.com/loopwork-ai/emcee/mcp"
)

Expand Down Expand Up @@ -73,14 +74,31 @@ The spec-path-or-url argument can be:
return retryablehttp.DefaultBackoff(min, max, attemptNum, resp)
}
}
client := retryClient.StandardClient()
opts = append(opts, mcp.WithClient(client))

// Set Authentication header if provided
// Set default headers if auth is provided
if auth != "" {
opts = append(opts, mcp.WithAuth(auth))
parts := strings.SplitN(auth, " ", 2)
if len(parts) == 1 {
// Only token provided, add Bearer prefix
logger.Warn("no auth scheme provided, automatically adding 'Bearer' prefix")
auth = "Bearer " + parts[0]
} else if len(parts) == 2 {
// Scheme and token provided, use as-is
auth = fmt.Sprintf("%s %s", parts[0], parts[1])
}

headers := http.Header{}
headers.Add("Authorization", auth)

retryClient.HTTPClient.Transport = &internal.HeaderTransport{
Base: retryClient.HTTPClient.Transport,
Headers: headers,
}
}

client := retryClient.StandardClient()
opts = append(opts, mcp.WithClient(client))

// Read OpenAPI specification data
var rpcInput io.Reader = os.Stdin
var specData []byte
Expand Down
22 changes: 22 additions & 0 deletions internal/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package internal

import "net/http"

// HeaderTransport is a custom RoundTripper that adds default headers to requests
type HeaderTransport struct {
Base http.RoundTripper
Headers http.Header
}

func (t *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, values := range t.Headers {
for _, value := range values {
req.Header.Add(key, value)
}
}
base := t.Base
if base == nil {
base = http.DefaultTransport
}
return base.RoundTrip(req)
}
29 changes: 10 additions & 19 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ import (
// ServerOption configures a Server
type ServerOption func(*Server) error

// WithAuth sets the HTTP Authorization header
func WithAuth(auth string) ServerOption {
return func(s *Server) error {
s.authHeader = auth
return nil
}
}

// WithClient sets the HTTP client
func WithClient(client *http.Client) ServerOption {
return func(s *Server) error {
Expand Down Expand Up @@ -88,19 +80,20 @@ func WithLogger(logger *slog.Logger) ServerOption {

// Server represents an MCP server that processes JSON-RPC requests
type Server struct {
doc libopenapi.Document
model *v3.Document
baseURL string
client *http.Client
info ServerInfo
authHeader string
logger *slog.Logger
doc libopenapi.Document
model *v3.Document
baseURL string
client *http.Client
info ServerInfo
logger *slog.Logger
}

// NewServer creates a new MCP server instance
func NewServer(opts ...ServerOption) (*Server, error) {
s := &Server{
client: http.DefaultClient,
client: &http.Client{
Transport: http.DefaultTransport,
},
}

// Apply options
Expand Down Expand Up @@ -441,9 +434,6 @@ func (s *Server) handleToolsCall(request *ToolCallRequest) (*ToolCallResponse, e
if reqBody != nil {
req.Header.Set("Content-Type", "application/json")
}
if s.authHeader != "" {
req.Header.Set("Authorization", s.authHeader)
}

// Send request
resp, err := s.client.Do(req)
Expand All @@ -468,6 +458,7 @@ func (s *Server) handleToolsCall(request *ToolCallRequest) (*ToolCallResponse, e
contentType := resp.Header.Get("Content-Type")
var content Content

// Create content based on response content type
if strings.HasPrefix(contentType, "image/") {
encoded := base64.StdEncoding.EncodeToString(body)
content = NewImageContent(encoded, contentType, []Role{RoleAssistant}, nil)
Expand Down
79 changes: 78 additions & 1 deletion mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http/httptest"
"testing"

"github.com/loopwork-ai/emcee/internal"
"github.com/loopwork-ai/emcee/jsonrpc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -131,8 +132,17 @@ func setupTestServer(t *testing.T) (*Server, *httptest.Server) {
// Create a small test image
imgData := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} // PNG header

// Track if auth header was checked
authHeaderChecked := false

var ts *httptest.Server
ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify auth header if present
if authHeader := r.Header.Get("Authorization"); authHeader != "" {
assert.Equal(t, "Bearer test-token", authHeader, "Authorization header should match")
authHeaderChecked = true
}

switch r.URL.Path {
case "/openapi.json":
w.Header().Set("Content-Type", "application/json")
Expand All @@ -147,6 +157,11 @@ func setupTestServer(t *testing.T) (*Server, *httptest.Server) {
assert.Equal(t, "5", limit)
assert.Equal(t, "dog", petType)

// For auth test case, verify the auth header was checked
if r.Header.Get("Authorization") != "" {
assert.True(t, authHeaderChecked, "Auth header should have been checked")
}

pets := []map[string]interface{}{
{"id": 1, "name": "Fluffy", "type": "dog"},
{"id": 2, "name": "Rover", "type": "dog"},
Expand Down Expand Up @@ -197,6 +212,12 @@ func setupTestServer(t *testing.T) (*Server, *httptest.Server) {
}
}))

client := ts.Client()
client.Transport = &internal.HeaderTransport{
Base: client.Transport,
Headers: http.Header{"Authorization": []string{"Bearer test-token"}},
}

// Create a server instance with the test server URL and spec
server, err := NewServer(
WithClient(ts.Client()),
Expand Down Expand Up @@ -339,16 +360,21 @@ func TestHandleToolsCall(t *testing.T) {
server, ts := setupTestServer(t)
defer ts.Close()

// Test with auth header
serverWithAuth, _ := setupTestServer(t)

// Create a small test image
imgData := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} // PNG header

tests := []struct {
name string
server *Server
request jsonrpc.Request
validate func(*testing.T, jsonrpc.Response)
}{
{
name: "GET request with query parameters",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "listPets", "arguments": {"limit": 5, "type": "dog"}}`), 1),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
Expand Down Expand Up @@ -391,6 +417,7 @@ func TestHandleToolsCall(t *testing.T) {
},
{
name: "POST request with body parameters",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "createPet", "arguments": {"name": "Whiskers", "age": 5}}`), 2),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
Expand Down Expand Up @@ -428,6 +455,7 @@ func TestHandleToolsCall(t *testing.T) {
},
{
name: "GET image request",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "getPetImage"}`), 3),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
Expand Down Expand Up @@ -464,6 +492,7 @@ func TestHandleToolsCall(t *testing.T) {
},
{
name: "Request with invalid operationId",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "nonexistentOperation"}`), 4),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
Expand All @@ -474,6 +503,7 @@ func TestHandleToolsCall(t *testing.T) {
},
{
name: "GET request with URL escaped parameters",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "getPet", "arguments": {"petId": "special pet"}}`), 5),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
Expand Down Expand Up @@ -504,11 +534,58 @@ func TestHandleToolsCall(t *testing.T) {
assert.Equal(t, "Special Pet", pet["name"])
},
},
{
name: "Request with auth header",
server: serverWithAuth,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "listPets", "arguments": {"limit": 5, "type": "dog"}}`), 6),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
assert.Equal(t, 6, response.ID.Value())
assert.Nil(t, response.Error)

var result ToolCallResponse
resultBytes, err := json.Marshal(response.Result)
require.NoError(t, err)
err = json.Unmarshal(resultBytes, &result)
require.NoError(t, err)

assert.Len(t, result.Content, 1)
assert.False(t, result.IsError)

// Verify the response content
content := result.Content[0]
assert.Equal(t, "text", content.Type)
assert.NotNil(t, content.Annotations)
assert.Contains(t, content.Annotations.Audience, RoleAssistant)

var textContent Content
contentBytes, err := json.Marshal(content)
assert.NoError(t, err)
err = json.Unmarshal(contentBytes, &textContent)
assert.NoError(t, err)

var pets []interface{}
err = json.Unmarshal([]byte(textContent.Text), &pets)
assert.NoError(t, err)
assert.Len(t, pets, 2)

// Verify the returned pets
for _, pet := range pets {
petMap := pet.(map[string]interface{})
assert.Equal(t, "dog", petMap["type"])
}
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
response := server.HandleRequest(tt.request)
var response jsonrpc.Response
if tt.server != nil {
response = tt.server.HandleRequest(tt.request)
} else {
response = server.HandleRequest(tt.request)
}
tt.validate(t, response)
})
}
Expand Down

0 comments on commit c4a70e3

Please sign in to comment.